From 6e7c3dc811d8c823577a3d3d880d35e566a49b4f Mon Sep 17 00:00:00 2001
From: jiandaobao
Date: Wed, 13 Aug 2025 10:03:38 +0800
Subject: [PATCH 01/11] l2 demo

---
 .../msprobe/pytorch/monitor/features.py       | 110 +++++++++++++
 .../msprobe/pytorch/monitor/module_hook.py    | 155 +++++++++++++++++-
 .../msprobe/pytorch/monitor/module_metric.py  |  25 +++
 3 files changed, 284 insertions(+), 6 deletions(-)

diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/features.py b/debug/accuracy_tools/msprobe/pytorch/monitor/features.py
index cfd0c1615..237adcd07 100644
--- a/debug/accuracy_tools/msprobe/pytorch/monitor/features.py
+++ b/debug/accuracy_tools/msprobe/pytorch/monitor/features.py
@@ -111,3 +111,113 @@ def cal_histc(tensor_cal, bins_total, min_val, max_val):
 @torch.no_grad()
 def get_nans(t):
     return torch.isnan(t).sum()
+
+
+def check_tensor_dim(tensor, n):
+    """Check that the tensor has at least n dimensions.
+    """
+    if not isinstance(tensor, torch.Tensor):
+        raise TypeError(
+            f"Input must be a PyTorch tensor. Got {type(tensor)} instead. "
+            f"Consider using torch.tensor() for conversion."
+        )
+
+    if tensor.dim() < n:
+        raise ValueError(
+            f"Tensor must have at least {n} dimensions. "
+            f"Got shape: {tuple(tensor.shape)} with {tensor.dim()} dims."
+        )
+
+
+@torch.no_grad()
+def max_eigenvalue(input_tensor: torch.Tensor, num_iterations=3):
+    input_tensor = input_tensor.float()
+    try:
+        check_tensor_dim(input_tensor, 2)
+    except (TypeError, ValueError) as e:
+        logger.warning(f"Calculate max eigenvalue failed: {e}")
+        return torch.tensor(0)
+    in_features = input_tensor.shape[1]
+    u_tensor = torch.randn(in_features).to(input_tensor.device)
+    u_norm = u_tensor.norm()
+    if u_norm.item() == 0:
+        return torch.tensor(0)
+    u_tensor = u_tensor / u_norm
+    input_seq = torch.matmul(input_tensor.T, input_tensor)
+    for _ in range(num_iterations):
+        v_tensor = torch.matmul(input_seq, u_tensor)
+        spectral_norm = torch.matmul(v_tensor.T, u_tensor)
+        v_norm = v_tensor.norm()
+        if v_norm > 0:
+            u_tensor = v_tensor / v_norm
+        else:
+            spectral_norm = torch.tensor(0)
+            break
+    return spectral_norm.sqrt()
+
+
+@torch.no_grad()
+def cal_entropy(qk_tensor, mask=None):
+    try:
+        check_tensor_dim(qk_tensor, 2)
+    except (TypeError, ValueError) as e:
+        logger.warning(f"Calculate entropy failed: {e}")
+        return torch.tensor(0), torch.tensor(0)
+    if mask is None:
+        mask = torch.tril(torch.ones(qk_tensor.shape[1], qk_tensor.shape[1])).to(
+            qk_tensor.device)
+    qk_tensor = qk_tensor - torch.amax(qk_tensor, dim=1, keepdim=True)
+    qk_tensor = qk_tensor.masked_fill(mask == 0, float('-inf'))
+    softmax_qkt = torch.nn.functional.softmax(qk_tensor.float(), dim=1)
+    # max over the softmaxed QK matrix
+    softmax_max = torch.max(torch.amax(softmax_qkt, dim=1))
+    entropy = torch.mean(-torch.nansum(softmax_qkt *
+                                       torch.log(softmax_qkt), dim=1))
+    return entropy, softmax_max
+
+
+@torch.no_grad()
+def cal_qkt(q_h, k_h, order="s,b,h,d"):
+    # q_h shape is [s, b, h, d]
+    try:
+        check_tensor_dim(q_h, 4)
+        check_tensor_dim(k_h, 4)
+    except (TypeError, ValueError) as e:
+        logger.warning(f"Calculate qk tensor failed: {e}")
+        return torch.tensor(0)
+
+    if order == "s,b,h,d":
+        qkt = torch.matmul(
+            q_h[:, 0, 0, :], k_h[:, 0, 0, :].t()) / q_h.shape[-1] ** 0.5
+    elif order == "b,s,h,d":
+        qkt = torch.matmul(
+            q_h[0, :, 0, :], k_h[0, :, 0, :].t()) / q_h.shape[-1] ** 0.5
+    else:
+        logger.warning("Calculate qk tensor failed: Order unsupported.")
+        qkt = torch.tensor(0)
+    return qkt
+
+
+@torch.no_grad()
+def cal_stable_rank(weight: torch.Tensor):
+    eig = max_eigenvalue(weight)
+    if eig == torch.tensor(0):
+        return torch.tensor(0), torch.tensor(0)
+    f_norm = torch.norm(weight, p="fro")
+    return f_norm / eig, eig
+
+
+@torch.no_grad()
+def cal_svd_entropy(weight: torch.Tensor, k=50):
+    epsilon = 1e-10
+    if isinstance(weight, torch.Tensor):
+        _, s, _ = torch.svd_lowrank(weight.float(), q=k)
+        s_sum = torch.sum(s)
+        if s_sum.item() == 0:
+            return torch.tensor(0)
+        p = s / torch.sum(s)
+        entropy = -torch.sum(p * torch.log2(p + epsilon))
+    else:
+        logger.warning("Calculate SVD entropy failed: Weight is not a tensor")
+        entropy = torch.tensor(0)
+    return entropy
diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py b/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py
index f39cd7a83..700c67c6c 100644
--- a/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py
+++ b/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py
@@ -40,9 +40,9 @@ from msprobe.pytorch.monitor.utils import get_param_struct
 from msprobe.pytorch.monitor.data_writers import SummaryWriterWithAD, CSVWriterWithAD, BaseWriterWithAD, WriterInput
 from msprobe.pytorch.monitor.distributed.wrap_distributed import api_register, create_hooks, op_aggregate, \
     get_process_group
-from msprobe.pytorch.monitor.features import get_sign_matches
+from msprobe.pytorch.monitor.features import get_sign_matches, cal_qkt
 from msprobe.pytorch.monitor.module_metric import get_metrics, get_summary_writer_tag_name, \
-    TensorMetrics, squash_param_name
+    TensorMetrics, squash_param_name, get_entropy_metric, get_sr_metric
 from msprobe.pytorch.monitor.optimizer_collect import OptimizerMonFactory
 from msprobe.pytorch.monitor.visualizer import HeatmapVisualizer
 
@@ -57,7 +57,7 @@ FORMAT_MAPPING = {
     MonitorConst.CSV: CSVWriterWithAD,
     MonitorConst.API: BaseWriterWithAD
 }
-
+start_step = 0
 
 def param_is_not_tensor_parallel_duplicate(param, tp_group):
     return (hasattr(param, 'tensor_model_parallel') and param.tensor_model_parallel) or (
         torch.distributed.get_rank(group=tp_group) == 0
@@ -83,7 +83,17 @@ class ModuleHookContext:
         self.actvgrad.clear()
 
 
-start_step = 0
+class FeatureHookContext:
+    def __init__(self, module_name):
+        self.step = 0
+        self.micro_step = 0
+        self.attention_feature = {}
+        self.linear_feature = {}
+        self.module_name = module_name
+
+    def reset(self):
+        self.attention_feature.clear()
+        self.linear_feature.clear()
 
 
 class OptimizerContext:
@@ -206,6 +216,7 @@ class TrainerMon:
         # TYPE3: variables reset whenever the config is updated or the monitoring state changes mid-training
         self.module_fwd_hook_context_by_module = defaultdict(ModuleHookContext)
         self.module_bwd_hook_context_by_module = defaultdict(ModuleHookContext)
+        self.feature_hook_context_by_module = defaultdict(FeatureHookContext)
         self.optimizer_context = defaultdict(OptimizerContext)
         self.cc_context = defaultdict(CommunicationContext)
         self.grad_context = GradContext()
@@ -298,6 +309,8 @@ class TrainerMon:
         self.cc_distribution = self.config.get("cc_distribution", {})
         self.stack_info = self.config.get('stack_info', False)
         self.monitor_mbs_grad = self.config.get('monitor_mbs_grad', False)
+        self.recording_l2_features = self.config.get("recording_l2_features", False)
+        self.sa_order = self.config.get("sa_order", "s,b,h,d")
         if not self.cc_distribution.get('enable', False):
             self.cc_log_only = False
@@ -356,6 +369,8 @@ class TrainerMon:
             logger.info_on_rank_0("> momentum and variance of adam is not monitored. ")
         if not self.wg_distribution:
             logger.info_on_rank_0("> weight grad of specified module is not monitored. ")
+        if not self.recording_l2_features:
+            logger.info_on_rank_0("> l2 features of specified module is not monitored. ")
         if not self.mg_direction:
             logger.info_on_rank_0('> grad and momentum direction will not be compared.')
         if not self.cc_distribution.get('enable', False):
@@ -537,6 +552,27 @@ class TrainerMon:
         if self.grad_context.actv:
             self.summary_writer.write_metrics(self.ops, self.grad_context.actv, step, MonitorConst.ACTVGRAD)
 
+    def write_metrics_if_not_empty(self, features, metrics, step, hook_name):
+            if len(features) == 0:
+                return
+            if hook_name in ["linear_hook"]:
+                self.summary_writer.write_metrics(metrics, features, step, hook_name, use_micro_step=False)
+            else:
+                self.summary_writer.write_metrics(metrics, features, step, hook_name, use_micro_step=True)
+            features.clear()
+
+    def write_features_tb(self, step):
+        if not self.recording_l2_features:
+            return
+        for context in self.feature_hook_context_by_module.values():
+            num_features = len(context.attention_feature) + len(context.linear_feature) + len(
+                context.token_feature) + len(context.norm_feature)
+            if num_features == 0:
+                continue
+            self.write_metrics_if_not_empty(context.attention_feature, ["entropy", "softmax_max"],
+                                            step, "attention_hook")
+            self.write_metrics_if_not_empty(context.linear_feature, ["sr", "kernel_norm"], step, "linear_hook")
+
     def write_param_tb(self, opt_context):
         if not self.param_distribution:
             return
@@ -691,6 +727,7 @@ class TrainerMon:
         if self.anomaly_data_factory:
             self.anomaly_data_factory.set_call_id(self.param_name_call_id)
         self.write_xy_tb(context.step)
+        self.write_features_tb(context.step)
         self.write_grad_tb(context.step)
         self.write_mv_tb(context)
         self.write_param_tb(context)
@@ -760,7 +797,8 @@ class TrainerMon:
             vpp_stage = f'{vpp_stage}{MonitorConst.NAME_SEP}'
             targets = [x for x, _ in model_chunk.named_modules()] if self.print_struct else self.config[
                 'targets'].keys()
-            hooked_count += self._hook_module(targets, model_chunk, vpp_stage)
+            l2_target_names = self.config.get('l2_targets', '')
+            hooked_count += self._hook_module(targets, l2_target_names, model_chunk, vpp_stage)
 
         logger.info_on_rank_0(f"> {hooked_count} modules are monitored.")
@@ -801,6 +839,9 @@ class TrainerMon:
         for handle in self.handles['xy']:
             handle.remove()
         self.handles['xy'].clear()
+        for handle in self.handles['L2_features']:
+            handle.remove()
+        self.handles['L2_features'].clear()
         # clear the corresponding context caches
         for _, fwd_context in self.module_fwd_hook_context_by_module.items():
             fwd_context.reset()
@@ -941,7 +982,38 @@ class TrainerMon:
                 return pattern
         return ""
 
-    def _hook_module(self, target_names, module: torch.nn.Module, vpp_stage=''):
+    def _is_recording_module(self, module_name, l2_targets, vpp_stage, hook_name):
+
+        if len(l2_targets) > 0:
+            for pattern in [
+                vpp_stage + squash_param_name(module_name, self.squash_name),
+                vpp_stage + module_name,
+            ]:
+                if pattern in l2_targets:
+                    return pattern
+        elif hook_name in ["linear_hook", "norm_hook"]:
+                return vpp_stage + squash_param_name(module_name, self.squash_name)
+        return ""
+
+    def _get_linear_hook_target(self, module):
+        if isinstance(module, torch.nn.Embedding):
+            return ''
+        if hasattr(module, "num_embeddings") or hasattr(module, "vocab_start_index"):
+            return ''
+        for weight_name in ["weight", "wg"]:
+            if hasattr(module, weight_name) and isinstance(getattr(module, weight_name), torch.Tensor):
+                if getattr(module, weight_name) == 2:
+                    return weight_name
+        return ''
+
+    def _get_norm_hook_target(self, module):
+        for weight_name in ["weight"]:
+            if hasattr(module, weight_name) and isinstance(getattr(module, weight_name), torch.Tensor):
+                if getattr(module, weight_name) == 1:
+                    return weight_name
+        return ''
+
+    def _hook_module(self, target_names, l2_target_names, module: torch.nn.Module, vpp_stage=''):
         if '_modules' not in module.__dict__:
             # nothing to hook
             return 0
@@ -1020,6 +1092,56 @@ class TrainerMon:
                     context.micro_step = 0
             return
 
+        def extract_attention_feature_hook(module, module_input, module_output, name):
+            if is_recomputation() or not module.training:
+                return
+
+            if module not in self.feature_hook_context_by_module:
+                self.feature_hook_context_by_module[module] = FeatureHookContext(name)
+            context: FeatureHookContext = self.feature_hook_context_by_module[module]
+            tbtag_tensor_map = {}
+            if len(module_input) < 2:
+                raise ValueError("the length of module_input in attention hook's module "
+                                 "should be greater than or equal to 2.")
+            q_h = module_input[0]
+            k_h = module_input[1]
+            qkt = cal_qkt(q_h, k_h, order=self.sa_order)
+            tbtag_tensor_map.update(
+                self.build_tbtag_tensor_map(f'{context.module_name}.attention',
+                                            '', 'qkt', qkt)
+            )
+            get_entropy_metric(tbtag_tensor_map, context.attention_feature)
+
+            context.micro_step += 1
+            if context.micro_step == self.micro_batch_number:
+                context.micro_step = 0
+                context.step += 1
+            return
+
+        def extract_linear_sr_hook(module, module_input, module_output, name):
+            if is_recomputation() or not module.training:
+                return
+            weight_name = self._get_linear_hook_target(module)
+            if weight_name == '':
+                return
+
+            if module not in self.feature_hook_context_by_module:
+                self.feature_hook_context_by_module[module] = FeatureHookContext(name)
+            context: FeatureHookContext = self.feature_hook_context_by_module[module]
+
+            if context.micro_step == self.micro_batch_number:
+                tbtag_tensor_map = {}
+                value = getattr(module, weight_name).data
+                tbtag_tensor_map.update(
+                    self.build_tbtag_tensor_map(f'{context.module_name}.linear',
+                                                '', 'sr', value)
+                )
+                get_sr_metric(tbtag_tensor_map, context.linear_feature)
+                context.micro_step = 0
+                context.step += 1
+            context.micro_step += 1
+            return
+
         def stack_hook(module, args, kwargs, module_output, name):
             if module not in self.module_fwd_hook_context_by_module:
                 self.module_fwd_hook_context_by_module[module] = ModuleHookContext(name)
@@ -1051,6 +1173,27 @@ class TrainerMon:
                     self.module_bwd_hook_context_by_module[submodule] = ModuleHookContext(name)
                 logger.info_on_rank_0(f"> {name} is monitored successfully")
                 hooked_count += 1
+        if not self.print_struct and self.recording_l2_features:
+            for module_name, submodule in module.named_modules():
+                func_map = {
+                    "attention_hook": extract_attention_feature_hook,
+                    "linear_hook": extract_linear_sr_hook,
+                }
+                hooks = ["attention_hook", "linear_hook"]
+                for hook_name in hooks:
+                    if hook_name not in l2_target_names:
+                        continue
+                    temp_names = l2_target_names[hook_name]
+                    name = self._is_recording_module(module_name, temp_names, vpp_stage, hook_name)
+                    if name:
+                        handle = submodule.register_forward_hook(partial(func_map[hook_name], name=name))
+                        print_feature_name = hook_name.split('_')[0]
+                        logger.info_on_rank_0(
+                            f'> {print_feature_name} features of {name} is monitored successfully')
+                        self.handles["L2_features"].append(handle)
+                        hooked_count += 1
+                        continue
+
         return hooked_count
 
     def _patch_grad_sync(self):
diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/module_metric.py b/debug/accuracy_tools/msprobe/pytorch/monitor/module_metric.py
index c5730d784..24e9b43d6 100644
--- a/debug/accuracy_tools/msprobe/pytorch/monitor/module_metric.py
+++ b/debug/accuracy_tools/msprobe/pytorch/monitor/module_metric.py
@@ -17,6 +17,7 @@ import re
 import torch
 
 from msprobe.pytorch.monitor.features import get_max, get_min, get_zeros, get_nans, get_norm, get_mean
+from msprobe.pytorch.monitor.features import cal_entropy, cal_stable_rank
 from msprobe.pytorch.monitor.utils import get_nan_tensor
 
 
@@ -185,3 +186,27 @@ def get_metrics(ops, tag2tensor, eps, out_dict=None):
         fun_metric = config_metric_registry.get(metric_name)
         out_dict[tag][metric_name] = fun_metric.get_metric(tensor, eps)
     return out_dict
+
+
+def get_sr_metric(tag2tensor, out_dict=None):
+    if out_dict is None:
+        out_dict = {}
+    for tag, tensor in tag2tensor.items():
+        if "sr" not in tag:
+            continue
+        if tag not in out_dict:
+            out_dict[tag] = {}
+        sr, eig = cal_stable_rank(tensor)
+        out_dict[tag]['sr'] = sr
+        out_dict[tag]['kernel_norm'] = eig
+
+
+def get_entropy_metric(tag2tensor, out_dict=None):
+    if out_dict is None:
+        out_dict = {}
+    for tag, tensor in tag2tensor.items():
+        if tag not in out_dict:
+            out_dict[tag] = {}
+        entropy, softmax_max = cal_entropy(tensor)
+        out_dict[tag]['entropy'] = entropy
+        out_dict[tag]['softmax_max'] = softmax_max
--
Gitee
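The `max_eigenvalue` routine added in PATCH 01 is a power iteration on $W^T W$, so what it actually estimates is the largest eigenvalue of the Gram matrix, i.e. the squared spectral norm of $W$ — hence the final square root. A minimal standalone sketch of the same scheme, illustrative only and assuming nothing beyond `torch` (the function name is not part of the patch):

```python
import torch

def power_iteration_sigma_max(w: torch.Tensor, num_iterations: int = 10) -> torch.Tensor:
    """Estimate the largest singular value of a 2D tensor via power iteration on w.T @ w."""
    w = w.float()
    u = torch.randn(w.shape[1], device=w.device)
    u = u / u.norm()
    gram = w.T @ w  # eigenvalues of the Gram matrix are the squared singular values of w
    sigma_sq = torch.tensor(0.0)
    for _ in range(num_iterations):
        v = gram @ u
        sigma_sq = v @ u  # Rayleigh quotient: converges to the top eigenvalue of gram
        u = v / v.norm()
    return sigma_sq.sqrt()

w = torch.randn(64, 32)
print(f"power iteration: {power_iteration_sigma_max(w):.4f}")
print(f"exact:           {torch.linalg.svdvals(w)[0]:.4f}")
```

With the patch's default of three iterations the estimate can still sit noticeably below the true spectral norm when the top singular values are closely spaced; the unit tests added in PATCH 04 use five iterations on a diagonal matrix, where convergence is fast.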
["0:0.self_attention.core_attention.flash_attention"], + "linear_hook": ["0:0.self_attention.linear_qkv", "0:1.self_attention.linear_qkv"] + }, + "recording_l2_features": true +} +``` +| 配置项 | 类型 | 说明 | 示例值 | +|--------|------|------|--------| +| **l2_targets.attention_hook** | List[str] | 指定需要监控的注意力层, 采集"entropy"和"sorftmax_max"指标,需要通过[打印模型结构功能](#打印模型结构)获取 | `["0:0.self_attention.core_attention.flash_attention"]` | +| **l2_targets.linear_hook** | List[str] | 指定需要监控的线性层, 采集"sr"和 "kernel_norm"指标,需要通过[打印模型结构功能](#打印模型结构)获取,支持传入空列表自动识别线性模块 | `["0:0.self_attention.linear_qkv", "0:1.self_attention.linear_qkv"]` | +| **recording_l2_features** | bool | 是否开启L2层特征数据采集 | `true` | + + +#### L2可解释特征监控指标说明 + +| **指标名称** | **适用Hook类型** | **数学定义/计算方式** | **监控意义** | +|--------------------|-------------------|-------------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------------| +| **entropy** | attention_hook | $H(p)=-\sum p_i \log p_i$,其中$p_i$为注意力权重 | 衡量注意力分布的不确定性,**低熵值**表示注意力集中 | +| **softmax_max** | attention_hook | $\max(\text{softmax}(QK^T/\sqrt{d}))$ | 反映注意力机制的聚焦程度,**高值**表示存在显著主导的注意力token | +| **sr(stable_rank)** | linear_hook | $\frac{\|W\|_F}{\|W\|_2}$(稳定秩,Frobenius范数除以谱范数) | 评估权重矩阵的有效秩,**低值**表示矩阵接近低秩不稳定状态 | +| **kernel_norm** | linear_hook | $\|W\|_F$(Frobenius范数) | 权重矩阵的缩谱范数,反映输入在矩阵最大奇异向量张成空间的放大系数 | + + + ### 输出格式和统计量 diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py b/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py index ae38fe6ad..0f1e3276f 100644 --- a/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py +++ b/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py @@ -565,8 +565,7 @@ class TrainerMon: if not self.recording_l2_features: return for context in self.feature_hook_context_by_module.values(): - num_features = len(context.attention_feature) + len(context.linear_feature) + len( - context.token_feature) + len(context.norm_feature) + num_features = len(context.attention_feature) + len(context.linear_feature) if num_features == 0: continue self.write_metrics_if_not_empty(context.attention_feature, ["entropy", "softmax_max"], @@ -991,7 +990,7 @@ class TrainerMon: ]: if pattern in l2_targets: return pattern - elif hook_name in ["linear_hook", "norm_hook"]: + elif hook_name in ["linear_hook"]: return vpp_stage + squash_param_name(module_name, self.squash_name) return "" -- Gitee From 5e404d306a0589ef96a7cf9605d268bcdb0ae024 Mon Sep 17 00:00:00 2001 From: jiandaobao Date: Fri, 15 Aug 2025 10:55:46 +0800 Subject: [PATCH 04/11] add ut --- .../msprobe/core/common/const.py | 1 + .../msprobe/core/monitor/utils.py | 24 +++++++ .../msprobe/pytorch/monitor/features.py | 16 ----- .../test/pytorch_ut/monitor/test_features.py | 72 ++++++++++++++++++- .../pytorch_ut/monitor/test_monitor_utils.py | 52 +++++++++++++- 5 files changed, 147 insertions(+), 18 deletions(-) diff --git a/debug/accuracy_tools/msprobe/core/common/const.py b/debug/accuracy_tools/msprobe/core/common/const.py index 9c659a22d..ad18ac709 100644 --- a/debug/accuracy_tools/msprobe/core/common/const.py +++ b/debug/accuracy_tools/msprobe/core/common/const.py @@ -797,6 +797,7 @@ class MonitorConst: ) DEEPSPEED_ZERO_OPT_FILTER = "DeepSpeedZeroOptimizer" RULE_NAME = ['AnomalyTurbulence', 'AnomalyNan'] + L2_HOOKS = ["linear_hook", "attention_hook"] SLICE_SIZE = 20480 # used for name diff --git a/debug/accuracy_tools/msprobe/core/monitor/utils.py 
From 5e404d306a0589ef96a7cf9605d268bcdb0ae024 Mon Sep 17 00:00:00 2001
From: jiandaobao
Date: Fri, 15 Aug 2025 10:55:46 +0800
Subject: [PATCH 04/11] add ut

---
 .../msprobe/core/common/const.py              |  1 +
 .../msprobe/core/monitor/utils.py             | 24 +++++++
 .../msprobe/pytorch/monitor/features.py       | 16 -----
 .../test/pytorch_ut/monitor/test_features.py  | 72 ++++++++++++++++++-
 .../pytorch_ut/monitor/test_monitor_utils.py  | 52 +++++++++++++-
 5 files changed, 147 insertions(+), 18 deletions(-)

diff --git a/debug/accuracy_tools/msprobe/core/common/const.py b/debug/accuracy_tools/msprobe/core/common/const.py
index 9c659a22d..ad18ac709 100644
--- a/debug/accuracy_tools/msprobe/core/common/const.py
+++ b/debug/accuracy_tools/msprobe/core/common/const.py
@@ -797,6 +797,7 @@ class MonitorConst:
     )
     DEEPSPEED_ZERO_OPT_FILTER = "DeepSpeedZeroOptimizer"
     RULE_NAME = ['AnomalyTurbulence', 'AnomalyNan']
+    L2_HOOKS = ["linear_hook", "attention_hook"]
 
     SLICE_SIZE = 20480
     # used for name
diff --git a/debug/accuracy_tools/msprobe/core/monitor/utils.py b/debug/accuracy_tools/msprobe/core/monitor/utils.py
index f19e14d89..a65e806b8 100644
--- a/debug/accuracy_tools/msprobe/core/monitor/utils.py
+++ b/debug/accuracy_tools/msprobe/core/monitor/utils.py
@@ -96,7 +96,25 @@ def validate_targets(targets):
             raise TypeError('key of targets should be module_name[str] in config.json')
         if not isinstance(field, dict):
             raise TypeError('values of targets should be cared filed e.g. {"input": "tensor"} in config.json')
+
+
+def validate_l2_targets(targets):
+    if not isinstance(targets, dict):
+        raise TypeError('l2_targets in config.json should be a dict')
+    for hook_name, target_list in targets.items():
+        if hook_name not in MsgConst.L2_HOOKS:
+            raise TypeError(f'key of l2_targets must be in {MsgConst.L2_HOOKS}, got {hook_name}')
+        if not isinstance(target_list, list):
+            raise TypeError('values of l2_targets should be a list in config.json')
+        for item in target_list:
+            if not isinstance(item, str):
+                raise TypeError(f'item of "{hook_name}" in l2_targets should be module_name[str] in config.json')
+
+
+def validate_recording_l2_features(recording_l2_features):
+    if not isinstance(recording_l2_features, bool):
+        raise TypeError("recording_l2_features should be a bool")
+
 
 def validate_print_struct(print_struct):
     if not isinstance(print_struct, bool):
@@ -216,6 +234,12 @@ def validate_config(config):
     targets = config.get("targets", {})
     validate_targets(targets)
 
+    l2_targets = config.get("l2_targets", {})
+    validate_l2_targets(l2_targets)
+
+    recording_l2_features = config.get("recording_l2_features", False)
+    validate_recording_l2_features(recording_l2_features)
+
     print_struct = config.get('print_struct', False)
     validate_print_struct(print_struct)
diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/features.py b/debug/accuracy_tools/msprobe/pytorch/monitor/features.py
index 237adcd07..1c163a58d 100644
--- a/debug/accuracy_tools/msprobe/pytorch/monitor/features.py
+++ b/debug/accuracy_tools/msprobe/pytorch/monitor/features.py
@@ -205,19 +205,3 @@ def cal_stable_rank(weight: torch.Tensor):
         return torch.tensor(0), torch.tensor(0)
     f_norm = torch.norm(weight, p="fro")
     return f_norm / eig, eig
-
-
-@torch.no_grad()
-def cal_svd_entropy(weight: torch.Tensor, k=50):
-    epsilon = 1e-10
-    if isinstance(weight, torch.Tensor):
-        _, s, _ = torch.svd_lowrank(weight.float(), q=k)
-        s_sum = torch.sum(s)
-        if s_sum.item() == 0:
-            return torch.tensor(0)
-        p = s / torch.sum(s)
-        entropy = -torch.sum(p * torch.log2(p + epsilon))
-    else:
-        logger.warning("Calculate SVD entropy failed: Weight is not a tensor")
-        entropy = torch.tensor(0)
-    return entropy
diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_features.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_features.py
index ff00cf749..a4f771ec1 100644
--- a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_features.py
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_features.py
@@ -1,8 +1,10 @@
 import unittest
+from unittest.mock import patch
+
 import torch
 from msprobe.pytorch.monitor.features import square_sum, get_min, get_mean, get_norm, get_max, get_zeros, \
     get_sign_matches, eff_rank, mNTK, lambda_max_subsample, cal_histc, get_nans
-
+from msprobe.pytorch.monitor.features import max_eigenvalue, cal_entropy, cal_qkt, cal_stable_rank
 
 class TestMathFunctions(unittest.TestCase):
     def test_square_sum(self):
@@ -87,6 +89,74 @@ class TestMathFunctions(unittest.TestCase):
         result = get_nans(tensor)
         self.assertEqual(result, 1)
 
+    def test_max_eigenvalue(self):
+        """Test max eigenvalue computation."""
+        # matrix with known eigenvalues
+        A = torch.diag(torch.tensor([3.0, 2.0, 1.0]))
+
+        # different iteration counts
+        eigval = max_eigenvalue(A, num_iterations=5)
+        self.assertAlmostEqual(eigval.item(), 3.0, delta=0.1)
+
+        # all-zero matrix
+        zero_matrix = torch.zeros(3, 3)
+        eigval = max_eigenvalue(zero_matrix)
+        self.assertAlmostEqual(eigval.item(), 0.0)
+
+    def test_cal_entropy(self):
+        """Test attention entropy computation."""
+        # simple attention scores
+        qk = torch.tensor([[1.0, 2.0, 3.0],
+                           [4.0, 5.0, 6.0],
+                           [7.0, 8.0, 9.0]])
+
+        # without a mask
+        entropy, softmax_max = cal_entropy(qk)
+        self.assertAlmostEqual(entropy, 0.4715, delta=0.1)
+        self.assertAlmostEqual(softmax_max, 0.7988, delta=0.1)
+
+        # an explicit mask identical to the default one
+        mask = torch.tensor([[1, 0, 0],
+                             [1, 1, 0],
+                             [1, 1, 1]], dtype=torch.float)
+        entropy, _ = cal_entropy(qk, mask)
+        self.assertAlmostEqual(entropy, 0.4715, delta=0.1)
+        self.assertAlmostEqual(softmax_max, 0.7988, delta=0.1)
+
+    @patch("msprobe.pytorch.monitor.features.logger")
+    def test_cal_qkt(self, mock_logger):
+        """Test QK^T computation."""
+        # s,b,h,d order
+        q = torch.randn(10, 2, 4, 8)  # [s, b, h, d]
+        k = torch.randn(10, 2, 4, 8)  # [s, b, h, d]
+        q_batch = torch.randn(2, 10, 4, 8)  # [b, s, h, d]
+        qkt = cal_qkt(q, k, order="s,b,h,d")
+        self.assertEqual(qkt.shape, (10, 10))  # [s, s]
+
+        # b,s,h,d order
+        qkt = cal_qkt(q_batch, q_batch, order="b,s,h,d")
+        self.assertEqual(qkt.shape, (10, 10))  # [s, s]
+
+        # invalid order
+        cal_qkt(q, k, order="invalid_order")
+        mock_logger.warning.assert_called_with(
+            "Calculate qk tensor failed: Order unsupported.")
+
+    def test_cal_stable_rank(self):
+        """Test stable rank computation."""
+        # matrix with a known spectral norm
+        A = torch.diag(torch.tensor([3.0, 2.0, 1.0]))
+        sr, eig = cal_stable_rank(A)
+
+        # check against the Frobenius norm
+        fro_norm = torch.norm(A, p='fro')
+        self.assertAlmostEqual(sr, fro_norm / 3.0, delta=.5)  # largest singular value is 3
+
+        # identity (orthogonal) matrix
+        ortho = torch.eye(5)
+        sr, eig = cal_stable_rank(ortho)
+        self.assertAlmostEqual(sr, torch.tensor(2.23 / 1), delta=.5)  # Frobenius norm should be ~2.24
+        self.assertAlmostEqual(eig, torch.tensor(1.0), delta=.1)  # spectral norm should be 1
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_monitor_utils.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_monitor_utils.py
index 83e8217c8..741f33623 100644
--- a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_monitor_utils.py
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_monitor_utils.py
@@ -8,7 +8,7 @@ from msprobe.core.common.const import MonitorConst
 from msprobe.core.monitor.utils import filter_special_chars, MsgConst, validate_ops, validate_ranks, \
     validate_targets, validate_print_struct, validate_ur_distribution, validate_xy_distribution, \
     validate_mg_distribution, validate_wg_distribution, validate_cc_distribution, validate_alert, validate_config, \
-    get_output_base_dir
+    get_output_base_dir, validate_l2_targets, validate_recording_l2_features
 from msprobe.pytorch.monitor.utils import get_param_struct
 from msprobe.pytorch.common.utils import is_recomputation
 
@@ -112,6 +112,56 @@ class TestValidationFunctions(unittest.TestCase):
         self.assertEqual(config["targets"], {"": {}})
         self.assertEqual(config["all_xy"], True)
 
+    # ===== validate_l2_targets tests =====
+    def test_validate_l2_targets_valid_input(self):
+        """Valid input should pass."""
+        valid_targets = {
+            "attention_hook": ["0:0.self_attention.core_attention.flash_attention"],
+            "linear_hook": []
+        }
+        validate_l2_targets(valid_targets)  # should not raise
+
+    def test_validate_l2_targets_invalid_root_type(self):
+        """Non-dict input should fail."""
+        with self.assertRaises(TypeError) as cm:
+            validate_l2_targets("not_a_dict")
+        self.assertEqual(str(cm.exception),
+                         'l2_targets in config.json should be a dict')
+
+    def test_validate_l2_targets_invalid_hook_name(self):
+        """Invalid hook_name should fail."""
+        with self.assertRaises(TypeError) as cm:
+            validate_l2_targets({"invalid_hook": ["module1"]})
+        self.assertIn(f'key of l2_targets must be in {MsgConst.L2_HOOKS}',
+                      str(cm.exception))
+
+    def test_validate_l2_targets_invalid_value_type(self):
+        """Invalid value type should fail."""
+        with self.assertRaises(TypeError) as cm:
+            validate_l2_targets({"hook1": "not_a_list"})
+        self.assertEqual(str(cm.exception),
+                         'values of l2_targets should be a list in config.json')
+
+    def test_validate_l2_targets_invalid_item_type(self):
+        """Invalid list item type should fail."""
+        with self.assertRaises(TypeError) as cm:
+            validate_l2_targets({"linear_hook": [123]})
+        self.assertEqual(str(cm.exception),
+                         'item of "linear_hook" in l2_targets should be module_name[str] in config.json')
+
+    # ===== validate_recording_l2_features tests =====
+    def test_validate_recording_l2_features_valid(self):
+        """Valid boolean input should pass."""
+        validate_recording_l2_features(True)  # should not raise
+        validate_recording_l2_features(False)  # should not raise
+
+    def test_validate_recording_l2_features_invalid_type(self):
+        """Invalid type should fail."""
+        with self.assertRaises(TypeError) as cm:
+            validate_recording_l2_features("xx")
+        self.assertEqual(str(cm.exception),
+                         "recording_l2_features should be a bool")
+
 
 class TestIsRecomputation(unittest.TestCase):
     @patch('inspect.stack')
--
Gitee
From 02ef42bbae225d2372c38bafc5831b89d272e40e Mon Sep 17 00:00:00 2001
From: jiandaobao
Date: Fri, 15 Aug 2025 11:13:32 +0800
Subject: [PATCH 05/11] bug fix

---
 .../msprobe/core/monitor/utils.py             |  4 +-
 .../msprobe/pytorch/monitor/module_hook.py    | 52 ++++++++-----------
 .../pytorch_ut/monitor/test_monitor_utils.py  |  2 +-
 3 files changed, 26 insertions(+), 32 deletions(-)

diff --git a/debug/accuracy_tools/msprobe/core/monitor/utils.py b/debug/accuracy_tools/msprobe/core/monitor/utils.py
index a65e806b8..3cdec3274 100644
--- a/debug/accuracy_tools/msprobe/core/monitor/utils.py
+++ b/debug/accuracy_tools/msprobe/core/monitor/utils.py
@@ -102,8 +102,8 @@ def validate_l2_targets(targets):
     if not isinstance(targets, dict):
         raise TypeError('l2_targets in config.json should be a dict')
     for hook_name, target_list in targets.items():
-        if hook_name not in MsgConst.L2_HOOKS:
-            raise TypeError(f'key of l2_targets must be in {MsgConst.L2_HOOKS}, got {hook_name}')
+        if hook_name not in MonitorConst.L2_HOOKS:
+            raise TypeError(f'key of l2_targets must be in {MonitorConst.L2_HOOKS}, got {hook_name}')
         if not isinstance(target_list, list):
             raise TypeError('values of l2_targets should be a list in config.json')
         for item in target_list:
diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py b/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py
index 0f1e3276f..f108591cb 100644
--- a/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py
+++ b/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py
@@ -59,6 +59,7 @@ FORMAT_MAPPING = {
 }
 start_step = 0
 
+
 def param_is_not_tensor_parallel_duplicate(param, tp_group):
     return (hasattr(param, 'tensor_model_parallel') and param.tensor_model_parallel) or (
         torch.distributed.get_rank(group=tp_group) == 0
@@ -285,6 +286,18 @@ class TrainerMon:
             cc_tensor.reset()
         return metrics
 
+    @staticmethod
+    def get_linear_hook_target(module):
+        if isinstance(module, torch.nn.Embedding):
+            return ''
+        if hasattr(module, "num_embeddings") or hasattr(module, "vocab_start_index"):
+            return ''
+        for weight_name in ["weight", "wg"]:
+            if hasattr(module, weight_name) and isinstance(getattr(module, weight_name), torch.Tensor):
+                if getattr(module, weight_name).dim() == 2:
+                    return weight_name
+        return ''
+
     def set_config(self):
         logger.info(f"current config: {self.config}")
         self.start_step = self.config.get("start_step", 0)
@@ -553,13 +566,13 @@ class TrainerMon:
             self.summary_writer.write_metrics(self.ops, self.grad_context.actv, step, MonitorConst.ACTVGRAD)
 
     def write_metrics_if_not_empty(self, features, metrics, step, hook_name):
-            if len(features) == 0:
-                return
-            if hook_name in ["linear_hook"]:
-                self.summary_writer.write_metrics(metrics, features, step, hook_name, use_micro_step=False)
-            else:
-                self.summary_writer.write_metrics(metrics, features, step, hook_name, use_micro_step=True)
-            features.clear()
+        if len(features) == 0:
+            return
+        if hook_name in ["linear_hook"]:
+            self.summary_writer.write_metrics(metrics, features, step, hook_name, use_micro_step=False)
+        else:
+            self.summary_writer.write_metrics(metrics, features, step, hook_name, use_micro_step=True)
+        features.clear()
 
     def write_features_tb(self, step):
         if not self.recording_l2_features:
@@ -991,27 +1004,9 @@ class TrainerMon:
                 if pattern in l2_targets:
                     return pattern
         elif hook_name in ["linear_hook"]:
-                return vpp_stage + squash_param_name(module_name, self.squash_name)
+            return vpp_stage + squash_param_name(module_name, self.squash_name)
         return ""
 
-    def _get_linear_hook_target(self, module):
-        if isinstance(module, torch.nn.Embedding):
-            return ''
-        if hasattr(module, "num_embeddings") or hasattr(module, "vocab_start_index"):
-            return ''
-        for weight_name in ["weight", "wg"]:
-            if hasattr(module, weight_name) and isinstance(getattr(module, weight_name), torch.Tensor):
-                if getattr(module, weight_name).dim() == 2:
-                    return weight_name
-        return ''
-
-    def _get_norm_hook_target(self, module):
-        for weight_name in ["weight"]:
-            if hasattr(module, weight_name) and isinstance(getattr(module, weight_name), torch.Tensor):
-                if getattr(module, weight_name).dim() == 1:
-                    return weight_name
-        return ''
-
     def _hook_module(self, target_names, l2_target_names, module: torch.nn.Module, vpp_stage=''):
         if '_modules' not in module.__dict__:
             # nothing to hook
diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_monitor_utils.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_monitor_utils.py
index 741f33623..7e7e3aeff 100644
--- a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_monitor_utils.py
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_monitor_utils.py
@@ -132,7 +132,7 @@ class TestValidationFunctions(unittest.TestCase):
         """Invalid hook_name should fail."""
         with self.assertRaises(TypeError) as cm:
             validate_l2_targets({"invalid_hook": ["module1"]})
-        self.assertIn(f'key of l2_targets must be in {MsgConst.L2_HOOKS}',
+        self.assertIn(f'key of l2_targets must be in {MonitorConst.L2_HOOKS}',
                       str(cm.exception))
 
     def test_validate_l2_targets_invalid_value_type(self):
--
Gitee

From f8e1ab617cd6adce2c8890f3a40a3206752a6684 Mon Sep 17 00:00:00 2001
From: jiandaobao
Date: Fri, 15 Aug 2025 16:13:53 +0800
Subject: [PATCH 06/11] bugfix

---
 debug/accuracy_tools/msprobe/core/common/const.py  |  1 +
 debug/accuracy_tools/msprobe/core/monitor/utils.py |  8 +++++++-
 debug/accuracy_tools/msprobe/docs/19.monitor.md    |  4 +++-
 .../msprobe/pytorch/monitor/features.py            |  2 +-
 .../test/pytorch_ut/monitor/test_features.py       |  2 +-
 .../test/pytorch_ut/monitor/test_monitor_utils.py  | 13 +++++++++++--
 6 files changed, 24 insertions(+), 6 deletions(-)

diff --git a/debug/accuracy_tools/msprobe/core/common/const.py b/debug/accuracy_tools/msprobe/core/common/const.py
index ad18ac709..a724c2910 100644
--- a/debug/accuracy_tools/msprobe/core/common/const.py
+++ b/debug/accuracy_tools/msprobe/core/common/const.py
@@ -798,6 +798,7 @@ class MonitorConst:
     DEEPSPEED_ZERO_OPT_FILTER = "DeepSpeedZeroOptimizer"
     RULE_NAME = ['AnomalyTurbulence', 'AnomalyNan']
     L2_HOOKS = ["linear_hook", "attention_hook"]
+    SA_ORDERS = ["s,b,h,d", "b,s,h,d"]
 
     SLICE_SIZE = 20480
     # used for name
diff --git a/debug/accuracy_tools/msprobe/core/monitor/utils.py b/debug/accuracy_tools/msprobe/core/monitor/utils.py
index 3cdec3274..658ccb2b1 100644
--- a/debug/accuracy_tools/msprobe/core/monitor/utils.py
+++ b/debug/accuracy_tools/msprobe/core/monitor/utils.py
@@ -115,7 +115,10 @@ def validate_recording_l2_features(recording_l2_features):
     if not isinstance(recording_l2_features, bool):
         raise TypeError("recording_l2_features should be a bool")
 
-
+def validate_sa_order(sa_order):
+    if sa_order not in MonitorConst.SA_ORDERS:
+        raise TypeError(f'sa_order must be in {MonitorConst.SA_ORDERS}, got {sa_order}')
+
 def validate_print_struct(print_struct):
     if not isinstance(print_struct, bool):
         raise TypeError("print_struct should be a bool")
@@ -239,6 +242,9 @@ def validate_config(config):
     recording_l2_features = config.get("recording_l2_features", False)
     validate_recording_l2_features(recording_l2_features)
 
+    sa_order = config.get("sa_order", False)
+    validate_sa_order(sa_order)
+
     print_struct = config.get('print_struct', False)
     validate_print_struct(print_struct)
diff --git a/debug/accuracy_tools/msprobe/docs/19.monitor.md b/debug/accuracy_tools/msprobe/docs/19.monitor.md
index de1e9374e..96f421371 100644
--- a/debug/accuracy_tools/msprobe/docs/19.monitor.md
+++ b/debug/accuracy_tools/msprobe/docs/19.monitor.md
@@ -311,7 +311,8 @@ param_name can be obtained via nn.Module's `named_parameters()` interface.
         "attention_hook": ["0:0.self_attention.core_attention.flash_attention"],
         "linear_hook": ["0:0.self_attention.linear_qkv", "0:1.self_attention.linear_qkv"]
     },
-    "recording_l2_features": true
+    "recording_l2_features": true,
+    "sa_order": "b,s,h,d"
 }
 ```
@@ -319,6 +320,7 @@ param_name can be obtained via nn.Module's `named_parameters()` interface.
 | **l2_targets.attention_hook** | List[str] | Attention layers to monitor; collects the "entropy" and "softmax_max" metrics; layer names must be obtained via the [print model structure](#打印模型结构) feature | `["0:0.self_attention.core_attention.flash_attention"]` |
 | **l2_targets.linear_hook** | List[str] | Linear layers to monitor; collects the "sr" and "kernel_norm" metrics; layer names must be obtained via the [print model structure](#打印模型结构) feature; an empty list auto-detects linear modules | `["0:0.self_attention.linear_qkv", "0:1.self_attention.linear_qkv"]` |
 | **recording_l2_features** | bool | Whether to enable L2 feature data collection | `true` |
+| **sa_order** | str | Dimension order of the attention inputs (Q, K) used when computing attention_hook metrics; supports "s,b,h,d" and "b,s,h,d"; defaults to "s,b,h,d" | `"s,b,h,d"` |
 
 
 #### L2 interpretable feature metrics
diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/features.py b/debug/accuracy_tools/msprobe/pytorch/monitor/features.py
index 1c163a58d..960f3dabe 100644
--- a/debug/accuracy_tools/msprobe/pytorch/monitor/features.py
+++ b/debug/accuracy_tools/msprobe/pytorch/monitor/features.py
@@ -170,7 +170,7 @@ def cal_entropy(qk_tensor, mask=None):
     qk_tensor = qk_tensor.masked_fill(mask == 0, float('-inf'))
     softmax_qkt = torch.nn.functional.softmax(qk_tensor.float(), dim=1)
     # max over the softmaxed QK matrix
-    softmax_max = torch.max(torch.amax(softmax_qkt, dim=1))
+    softmax_max = torch.mean(torch.amax(softmax_qkt, dim=1))
     entropy = torch.mean(-torch.nansum(softmax_qkt *
                                        torch.log(softmax_qkt), dim=1))
     return entropy, softmax_max
diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_features.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_features.py
index a4f771ec1..a2e7ecb51 100644
--- a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_features.py
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_features.py
@@ -119,7 +119,7 @@ class TestMathFunctions(unittest.TestCase):
         mask = torch.tensor([[1, 0, 0],
                              [1, 1, 0],
                              [1, 1, 1]], dtype=torch.float)
-        entropy, _ = cal_entropy(qk, mask)
+        entropy, softmax_max = cal_entropy(qk, mask)
         self.assertAlmostEqual(entropy, 0.4715, delta=0.1)
         self.assertAlmostEqual(softmax_max, 0.7988, delta=0.1)
 
diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_monitor_utils.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_monitor_utils.py
index 7e7e3aeff..586280f74 100644
--- a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_monitor_utils.py
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_monitor_utils.py
@@ -8,7 +8,7 @@ from msprobe.core.common.const import MonitorConst
 from msprobe.core.monitor.utils import filter_special_chars, MsgConst, validate_ops, validate_ranks, \
     validate_targets, validate_print_struct, validate_ur_distribution, validate_xy_distribution, \
     validate_mg_distribution, validate_wg_distribution, validate_cc_distribution, validate_alert, validate_config, \
-    get_output_base_dir, validate_l2_targets, validate_recording_l2_features
+    get_output_base_dir, validate_l2_targets, validate_recording_l2_features, validate_sa_order
 from msprobe.pytorch.monitor.utils import get_param_struct
 from msprobe.pytorch.common.utils import is_recomputation
 
@@ -138,7 +138,7 @@ class TestValidationFunctions(unittest.TestCase):
     def test_validate_l2_targets_invalid_value_type(self):
         """Invalid value type should fail."""
         with self.assertRaises(TypeError) as cm:
-            validate_l2_targets({"hook1": "not_a_list"})
+            validate_l2_targets({"linear_hook": "not_a_list"})
         self.assertEqual(str(cm.exception),
                          'values of l2_targets should be a list in config.json')
 
@@ -161,7 +161,16 @@ class TestValidationFunctions(unittest.TestCase):
         validate_recording_l2_features("xx")
         self.assertEqual(str(cm.exception),
                          "recording_l2_features should be a bool")
+
+    def test_valid_orders(self):
+        validate_sa_order("b,s,h,d")  # should not raise
+        validate_sa_order("s,b,h,d")  # should not raise
 
+    def test_invalid_orders(self):
+        with self.assertRaises(TypeError) as cm:
+            validate_sa_order("xx")
+        self.assertEqual(str(cm.exception),
+                         f'sa_order must be in {MonitorConst.SA_ORDERS}, got xx')
 
 class TestIsRecomputation(unittest.TestCase):
     @patch('inspect.stack')
--
Gitee
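The `sa_order` values validated above change which axes `cal_qkt` slices: for a single batch element and head it always reduces Q and K to `[s, d]` matrices, but the sequence axis sits at dim 0 for "s,b,h,d" and dim 1 for "b,s,h,d". A shape-only sketch (the dimensions are arbitrary):

```python
import torch

s, b, h, d = 6, 2, 4, 8
q = torch.randn(s, b, h, d)
k = torch.randn(s, b, h, d)

# "s,b,h,d": batch 0, head 0 -> [s, d] slices, so QK^T is [s, s]
qkt_sbhd = q[:, 0, 0, :] @ k[:, 0, 0, :].T / d ** 0.5
print(qkt_sbhd.shape)  # torch.Size([6, 6])

# "b,s,h,d": the sequence axis moves to dim 1, so the slice pattern changes
q_bshd = q.permute(1, 0, 2, 3).contiguous()
qkt_bshd = q_bshd[0, :, 0, :] @ q_bshd[0, :, 0, :].T / d ** 0.5
print(qkt_bshd.shape)  # torch.Size([6, 6])
```

Passing the wrong order silently computes attention statistics over the wrong axis, which is why the config key is validated rather than trusted.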
From c8dfed74c75d493f996906db48b24a6c4dde6f1c Mon Sep 17 00:00:00 2001
From: jiandaobao
Date: Fri, 15 Aug 2025 16:35:21 +0800
Subject: [PATCH 07/11] bugfix

---
 debug/accuracy_tools/msprobe/core/monitor/utils.py  |  4 +++-
 .../test/pytorch_ut/monitor/test_monitor_utils.py   | 10 +++++-----
 2 files changed, 8 insertions(+), 6 deletions(-)

diff --git a/debug/accuracy_tools/msprobe/core/monitor/utils.py b/debug/accuracy_tools/msprobe/core/monitor/utils.py
index 658ccb2b1..111f4808d 100644
--- a/debug/accuracy_tools/msprobe/core/monitor/utils.py
+++ b/debug/accuracy_tools/msprobe/core/monitor/utils.py
@@ -116,6 +116,8 @@ def validate_recording_l2_features(recording_l2_features):
         raise TypeError("recording_l2_features should be a bool")
 
 def validate_sa_order(sa_order):
+    if isinstance(sa_order, str):
+        sa_order = sa_order.replace(' ', '')
     if sa_order not in MonitorConst.SA_ORDERS:
         raise TypeError(f'sa_order must be in {MonitorConst.SA_ORDERS}, got {sa_order}')
 
@@ -245,7 +247,7 @@ def validate_config(config):
     recording_l2_features = config.get("recording_l2_features", False)
     validate_recording_l2_features(recording_l2_features)
 
-    sa_order = config.get("sa_order", False)
+    sa_order = config.get("sa_order", "s,b,h,d")
     validate_sa_order(sa_order)
 
     print_struct = config.get('print_struct', False)
diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_monitor_utils.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_monitor_utils.py
index 586280f74..2f10d4f12 100644
--- a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_monitor_utils.py
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_monitor_utils.py
@@ -119,7 +119,7 @@ class TestValidationFunctions(unittest.TestCase):
             "attention_hook": ["0:0.self_attention.core_attention.flash_attention"],
             "linear_hook": []
         }
-        validate_l2_targets(valid_targets)  # should not raise
+        validate_l2_targets(valid_targets)
 
     def test_validate_l2_targets_invalid_root_type(self):
         """Non-dict input should fail."""
@@ -152,8 +152,8 @@ class TestValidationFunctions(unittest.TestCase):
     # ===== validate_recording_l2_features tests =====
     def test_validate_recording_l2_features_valid(self):
         """Valid boolean input should pass."""
-        validate_recording_l2_features(True)  # should not raise
-        validate_recording_l2_features(False)  # should not raise
+        validate_recording_l2_features(True)
+        validate_recording_l2_features(False)
 
     def test_validate_recording_l2_features_invalid_type(self):
         """Invalid type should fail."""
@@ -163,8 +163,8 @@ class TestValidationFunctions(unittest.TestCase):
                          "recording_l2_features should be a bool")
 
     def test_valid_orders(self):
-        validate_sa_order("b,s,h,d")  # should not raise
-        validate_sa_order("s,b,h,d")  # should not raise
+        validate_sa_order("b,s,h,d")
+        validate_sa_order("s, b,h, d")
 
     def test_invalid_orders(self):
         with self.assertRaises(TypeError) as cm:
--
Gitee

From d0d403caf7a4b0d62c22f8334b1b3434b6fdac46 Mon Sep 17 00:00:00 2001
From: jiandaobao
Date: Fri, 15 Aug 2025 16:38:54 +0800
Subject: [PATCH 08/11] bugfix

---
 debug/accuracy_tools/msprobe/core/monitor/utils.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/debug/accuracy_tools/msprobe/core/monitor/utils.py b/debug/accuracy_tools/msprobe/core/monitor/utils.py
index 111f4808d..2c7a20a04 100644
--- a/debug/accuracy_tools/msprobe/core/monitor/utils.py
+++ b/debug/accuracy_tools/msprobe/core/monitor/utils.py
@@ -115,12 +115,14 @@ def validate_recording_l2_features(recording_l2_features):
     if not isinstance(recording_l2_features, bool):
         raise TypeError("recording_l2_features should be a bool")
 
+
 def validate_sa_order(sa_order):
     if isinstance(sa_order, str):
         sa_order = sa_order.replace(' ', '')
     if sa_order not in MonitorConst.SA_ORDERS:
         raise TypeError(f'sa_order must be in {MonitorConst.SA_ORDERS}, got {sa_order}')
 
+
 def validate_print_struct(print_struct):
     if not isinstance(print_struct, bool):
         raise TypeError("print_struct should be a bool")
--
Gitee

From 30772fe4756b1039de3cea3b33f699901678a755 Mon Sep 17 00:00:00 2001
From: jiandaobao
Date: Mon, 18 Aug 2025 10:48:18 +0800
Subject: [PATCH 09/11] bugfix

---
 .../accuracy_tools/msprobe/pytorch/monitor/module_hook.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py b/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py
index f108591cb..3e967318e 100644
--- a/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py
+++ b/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py
@@ -1102,7 +1102,7 @@ class TrainerMon:
             qkt = cal_qkt(q_h, k_h, order=self.sa_order)
             tbtag_tensor_map.update(
                 self.build_tbtag_tensor_map(f'{context.module_name}.attention',
-                                            '', 'qkt', qkt)
+                                            f'{MonitorConst.NAME_SEP}{context.micro_step}', 'qkt', qkt)
             )
             get_entropy_metric(tbtag_tensor_map, context.attention_feature)
 
@@ -1123,7 +1123,7 @@ class TrainerMon:
             self.feature_hook_context_by_module[module] = FeatureHookContext(name)
         context: FeatureHookContext = self.feature_hook_context_by_module[module]
 
-        if context.micro_step == self.micro_batch_number:
+        if context.micro_step == (self.micro_batch_number - 1):
             tbtag_tensor_map = {}
             value = getattr(module, weight_name).data
             tbtag_tensor_map.update(
@@ -1131,9 +1131,11 @@ class TrainerMon:
                                             '', 'sr', value)
             )
             get_sr_metric(tbtag_tensor_map, context.linear_feature)
+
+            context.micro_step += 1
+            if context.micro_step == self.micro_batch_number:
                 context.micro_step = 0
                 context.step += 1
-            context.micro_step += 1
             return
--
Gitee
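After this fix the linear feature is sampled exactly once per optimizer step, on the last micro batch (the old `== self.micro_batch_number` guard could never fire, since the counter was reset before reaching that value). The corrected control flow reduces to the following sketch; the field names follow `FeatureHookContext`, and `micro_batch_number = 4` is an assumed example value:

```python
class Ctx:
    def __init__(self):
        self.step = 0
        self.micro_step = 0

micro_batch_number = 4
ctx = Ctx()
for call in range(8):  # two global steps of four micro batches each
    if ctx.micro_step == micro_batch_number - 1:
        print(f"log linear sr at step {ctx.step} (micro batch {ctx.micro_step})")
    ctx.micro_step += 1
    if ctx.micro_step == micro_batch_number:  # roll over to the next global step
        ctx.micro_step = 0
        ctx.step += 1
```

Running the sketch fires the log line once at micro batch 3 of step 0 and once at micro batch 3 of step 1, which is the intended once-per-step cadence.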
**支持的hook类型**:
• `attention_hook`:监控注意力层
  ▪️ 采集指标:`entropy` `softmax_max`
  ▪️ 必须通过[模型结构打印](#模型结构打印)获取准确层名
  ▪️ 不配置或配置空列表均表示不采集
• `linear_hook`:监控线性层
  ▪️ 采集指标:`sr`, `kernel_norm`
  ▪️ 必须通过[模型结构打印](#模型结构打印)获取准确层名, 不配置表示不采集
  ▪️ 配置空列表会自动识别符合条件的层(包含`weight`或`wg`2D参数属性的层) | 是 | +| **recording_l2_features** | bool | 是否开启L2层特征数据采集,默认为false表示不采集 | 否 | +| **sa_order** | str | 计算`attention_hook`内指标时,指定Attention输入(Q,K)的张量维度排列顺序,支持"s,b,h,d"和"b,s,h,d", 默认为"s,b,h,d"表示输入维度顺序为**s**equence_len​->**b**atch_size​->num_**h**eads​->head_**d**im | 否 | #### L2可解释特征监控指标说明 @@ -333,8 +332,6 @@ param_name可以通过nn.Module的接口`named_parameters()`获取。 | **kernel_norm** | linear_hook | $\|W\|_F$(Frobenius范数) | 权重矩阵的缩谱范数,反映输入在矩阵最大奇异向量张成空间的放大系数 | - - ### 输出格式和统计量 工具配置示例: diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py b/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py index 3e967318e..0a2cd447c 100644 --- a/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py +++ b/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py @@ -566,12 +566,10 @@ class TrainerMon: self.summary_writer.write_metrics(self.ops, self.grad_context.actv, step, MonitorConst.ACTVGRAD) def write_metrics_if_not_empty(self, features, metrics, step, hook_name): - if len(features) == 0: + if not features or len(features) == 0: return - if hook_name in ["linear_hook"]: - self.summary_writer.write_metrics(metrics, features, step, hook_name, use_micro_step=False) - else: - self.summary_writer.write_metrics(metrics, features, step, hook_name, use_micro_step=True) + use_micro_step = hook_name not in ["linear_hook"] + self.summary_writer.write_metrics(metrics, features, step, hook_name, use_micro_step=use_micro_step) features.clear() def write_features_tb(self, step): @@ -1095,8 +1093,11 @@ class TrainerMon: context: FeatureHookContext = self.feature_hook_context_by_module[module] tbtag_tensor_map = {} if len(module_input) < 2: - raise ValueError("the length of module_input in attention hook's module " - "should be greater than or equal to 2.") + logger.warning( + f"Length of module_input in attention hook ({name}) is {len(module_input)}, " + "expected >= 2. Skipping feature extraction for this module." + ) + return q_h = module_input[0] k_h = module_input[1] qkt = cal_qkt(q_h, k_h, order=self.sa_order) -- Gitee From 5d45e263f7905a9fd4c120508946bb9888a16ebe Mon Sep 17 00:00:00 2001 From: jiandaobao Date: Mon, 18 Aug 2025 20:28:25 +0800 Subject: [PATCH 11/11] suggestions --- debug/accuracy_tools/msprobe/docs/19.monitor.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/debug/accuracy_tools/msprobe/docs/19.monitor.md b/debug/accuracy_tools/msprobe/docs/19.monitor.md index 1128bf55d..69b9e5f23 100644 --- a/debug/accuracy_tools/msprobe/docs/19.monitor.md +++ b/debug/accuracy_tools/msprobe/docs/19.monitor.md @@ -317,7 +317,7 @@ param_name可以通过nn.Module的接口`named_parameters()`获取。 ``` | 配置项 | 类型 | 说明 | 是否必选 | |--------|------|------|--------| -| **l2_targets** | Dict[str, List[str]] | 指定需要监控的模型层配置
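The auto-detection rule documented in the new `l2_targets` row — a layer qualifies for `linear_hook` when it carries a 2D `weight` or `wg` tensor and is not an embedding — can be exercised in isolation. A sketch mirroring `get_linear_hook_target`; the helper name and the toy model are illustrative, not part of the patch:

```python
import torch
from torch import nn

def linear_hook_target(module: nn.Module) -> str:
    """Return the qualifying weight attribute name, or '' if the module is excluded."""
    if isinstance(module, nn.Embedding):
        return ''
    if hasattr(module, "num_embeddings") or hasattr(module, "vocab_start_index"):
        return ''  # embedding-like modules are excluded even if not nn.Embedding
    for weight_name in ("weight", "wg"):
        w = getattr(module, weight_name, None)
        if isinstance(w, torch.Tensor) and w.dim() == 2:
            return weight_name
    return ''

model = nn.Sequential(nn.Embedding(10, 8), nn.Linear(8, 8), nn.LayerNorm(8))
for name, sub in model.named_modules():
    if linear_hook_target(sub):
        print(f"would attach linear_hook to: {name or 'root'}")
# -> only the nn.Linear qualifies; LayerNorm's weight is 1D and the Embedding is excluded
```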
**支持的hook类型**:
• `attention_hook`:监控注意力层
  ▪️ 采集指标:`entropy` `softmax_max`
  ▪️ 必须通过[模型结构打印](#模型结构打印)获取准确层名
  ▪️ 不配置或配置空列表均表示不采集
• `linear_hook`:监控线性层
  ▪️ 采集指标:`sr`, `kernel_norm`
  ▪️ 必须通过[模型结构打印](#模型结构打印)获取准确层名, 不配置表示不采集
  ▪️ 配置空列表会自动识别符合条件的层(包含`weight`或`wg`2D参数属性的层) | 是 | +| **l2_targets** | Dict[str, List[str]] | 指定需要监控的模型层配置
**支持的hook类型**:
• `attention_hook`:监控注意力层
  ▪️ 采集指标:`entropy` `softmax_max`
  ▪️ 必须通过[打印模型结构](#打印模型结构)获取准确层名
  ▪️ 不配置或配置空列表均表示不采集
• `linear_hook`:监控线性层
  ▪️ 采集指标:`sr`, `kernel_norm`
  ▪️ 必须通过[打印模型结构](#打印模型结构)获取准确层名, 不配置表示不采集
  ▪️ 配置空列表会自动识别符合条件的层(包含`weight`或`wg`2D参数属性的层) | 是 | | **recording_l2_features** | bool | 是否开启L2层特征数据采集,默认为false表示不采集 | 否 | | **sa_order** | str | 计算`attention_hook`内指标时,指定Attention输入(Q,K)的张量维度排列顺序,支持"s,b,h,d"和"b,s,h,d", 默认为"s,b,h,d"表示输入维度顺序为**s**equence_len​->**b**atch_size​->num_**h**eads​->head_**d**im | 否 | -- Gitee
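As a closing illustration of what the `sr` and `kernel_norm` columns are meant to surface: a well-conditioned weight matrix has a stable rank well above 1, while a collapsed (low-rank) one drives it toward 1. A sketch using exact SVD in place of the patch's power-iteration estimate:

```python
import torch

def stable_rank(w: torch.Tensor):
    """Return (||W||_F / ||W||_2, ||W||_2) using an exact SVD."""
    svals = torch.linalg.svdvals(w.float())
    return (torch.norm(w, p="fro") / svals[0]).item(), svals[0].item()

healthy = torch.randn(256, 256) / 256 ** 0.5                  # well-conditioned random weights
collapsed = torch.outer(torch.randn(256), torch.randn(256))   # rank-1 matrix

for label, w in [("healthy", healthy), ("rank-1", collapsed)]:
    sr, sigma = stable_rank(w)
    print(f"{label}: sr={sr:.2f}, kernel_norm={sigma:.2f}")
# A rank-1 matrix gives sr == 1 exactly; values drifting toward 1 during training
# are what the linear_hook's "sr" metric is meant to surface.
```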