diff --git a/debug/accuracy_tools/msprobe/docs/32.checkpoint_compare.md b/debug/accuracy_tools/msprobe/docs/32.checkpoint_compare.md index c49b4bfc8ee079cfdf2583c0c84372fe74aec6a7..c374b7822ea11623408a3b4b1e396bdd1035efed 100644 --- a/debug/accuracy_tools/msprobe/docs/32.checkpoint_compare.md +++ b/debug/accuracy_tools/msprobe/docs/32.checkpoint_compare.md @@ -1,6 +1,6 @@ -# 权重比对 +# 权重比对和权重度量 -msprobe 工具提供大模型权重比对功能。当前支持pytorch下megatron/mindspeed不同模型并行策略下的权重互相比对。 +msprobe 工具提供大模型权重比对和权重度量功能。当前支持pytorch下megatron/mindspeed不同模型并行策略下的权重互相比对和度量。 > **Attention:** Ensure megatron in the PYTHONPATH to load a megatron checkpoint. @@ -10,16 +10,21 @@ msprobe 工具提供大模型权重比对功能。当前支持pytorch下megatron ## 2. 工具使用 ```shell +# 权重对比 msprobe -f pytorch config_checking -c PATH/TO/A/CHECKPOINT PATH/TO/THE/OTHER/CHECKPOINT -s -o PATH/FOR/OUTPUT + +# 权重度量 +msprobe -f pytorch config_checking -m PATH/TO/A/CHECKPOINT -o PATH/FOR/OUTPUT ``` **命令行参数说明**: -| 参数名 | 说明 | 是否必选 | -|------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------| -| -c --compare | 需要比较的两个checkpoint路径 | 是 | -| -s --ckpt-sim | store_true。使能权重比对功能,否则为配置比对 | 是 | -| -o 或 --out | 权重比对结果文件存盘目录,默认为'ckpt_compare_out.json' | 否 | +| 参数名 | 说明 | 是否必选 | +|---------------|-----------------------------------------|------| +| -c --compare | 需要比较的两个checkpoint路径 | 是 | +| -m --measure | 需要度量的checkpoint路径 | 是 | +| -s --ckpt-sim | store_true。使能权重比对功能,否则为配置比对 | 否 | +| -o 或 --out | 权重比对结果文件存盘目录,默认为'ckpt_compare_out.json' | 否 | @@ -37,6 +42,8 @@ Found xxx total parameters across all ranks ``` Sample result: + +权重对比结果: ```json { "embedding.word_embeddings.weight": { @@ -57,4 +64,20 @@ Sample result: ] } } +``` + +权重度量结果: +```json +{ + "embedding.word_embeddings.weight": { + "svd_entropy": 7.5917921325684, + 
def measure_checkpoints(ckpt_path, output_path) -> Dict:
    """Compute per-parameter weight metrics for a single checkpoint.

    Every function registered in ``MEASURING_FUNC`` is applied to every
    weight returned by ``load_megatron_weights(ckpt_path)``; the collected
    values are saved to ``output_path`` as JSON and also returned.

    Args:
        ckpt_path (str): Path to the checkpoint directory to measure.
        output_path (str): Path of the JSON file the results are saved to.

    Returns:
        Dict: Metrics for each parameter, with the following structure:
            {
                "param_name": {
                    "svd_entropy": ...,         # first value returned by the svd_entropy metric
                    "max_singular_value": ...,  # second value returned by the svd_entropy metric
                    "stable_rank": ...,         # value returned by the stable_rank metric
                },
                ...
            }
        A metric whose computation raised is stored as the string 'error'.
    """
    # Keys that a metric fills in besides its own name. 'svd_entropy' returns
    # an (entropy, max singular value) pair, so it owns a second result key.
    companion_keys = {'svd_entropy': 'max_singular_value'}

    # NOTE(review): the output path is validated before anything is written;
    # confirm check_file_or_directory_path accepts a not-yet-existing file.
    check_file_or_directory_path(output_path)
    weights = load_megatron_weights(ckpt_path)

    results = {}
    for key in tqdm(weights):
        tensor = weights[key]
        param_metrics = {}
        for metric, func in MEASURING_FUNC.items():
            companion = companion_keys.get(metric)
            try:
                value = func(tensor)
                if companion is None:
                    param_metrics[metric] = value
                else:
                    param_metrics[metric], param_metrics[companion] = value
            except Exception as e:
                # Best effort: one failing metric must not abort the whole
                # run. Mark every key the metric owns so each parameter
                # keeps the same result schema.
                param_metrics[metric] = 'error'
                if companion is not None:
                    param_metrics[companion] = 'error'
                logger.warning(f'Error when calculate {metric} of {key} for reason: {e}')
        results[key] = param_metrics

    # Write results to JSON file
    save_json(output_path, results, indent=4)
    logger.info(f"Measure results written to {output_path}")
    return results
def _entropy_and_top_singular(matrix, q):
    """Return ``(entropy, largest singular value)`` for one 2-D matrix."""
    matrix = matrix.float()
    # Never request more singular values than the matrix can have.
    rank = min(q, *matrix.shape)
    _, s, _ = torch.svd_lowrank(matrix, q=rank)
    total = torch.sum(s)
    if total.item() == 0:
        # All singular values are zero: the distribution is undefined.
        return float('nan'), s[0].item()
    p = s / total
    # Drop zero entries: 0 * log2(0) evaluates to NaN and would poison the
    # sum, while the limit of p * log2(p) for p -> 0 is 0.
    p = p[p > 0]
    entropy = -torch.sum(p * torch.log2(p)).item()
    return entropy, s[0].item()


def svd_entropy(param, q=200):
    """Compute the entropy of the normalized singular values of a weight.

    The singular values come from ``torch.svd_lowrank`` (at most ``q`` of
    them), are normalized to sum to 1, and their base-2 Shannon entropy is
    returned together with the first singular value.

    Args:
        param: ``torch.Tensor`` or ``numpy.ndarray``. It is moved to the
            module-level ``device`` before the computation.
        q (int): Upper bound on the number of singular values requested;
            clamped to the smaller matrix dimension.

    Returns:
        tuple: ``(entropy, s0)``. For a 2-D input both are floats; for a 3-D
        input both are lists with one entry per slice along dim 0; for any
        other rank both are NaN.
    """
    if isinstance(param, np.ndarray):
        param = torch.from_numpy(param)
    param = param.to(device)
    if param.ndim == 2:
        return _entropy_and_top_singular(param, q)
    if param.ndim == 3:
        pairs = [_entropy_and_top_singular(matrix, q) for matrix in param]
        return [pair[0] for pair in pairs], [pair[1] for pair in pairs]
    return float('nan'), float('nan')


def max_eigenvalue(tensor, num_iterations=3):
    """Estimate the largest singular value of a 2-D tensor by power iteration.

    The iteration runs on ``tensor.T @ tensor`` from a random start vector.
    The square root of the estimated top eigenvalue of that product is
    returned, so despite the name the result is on the scale of ``tensor``'s
    singular values.

    Args:
        tensor: 2-D ``torch.Tensor``.
        num_iterations (int): Number of power-iteration steps, at least 1.

    Returns:
        torch.Tensor: 0-dim tensor holding the estimate.

    Raises:
        ValueError: If ``num_iterations`` is smaller than 1.
    """
    if num_iterations < 1:
        raise ValueError("num_iterations must be at least 1")
    tensor = tensor.float()
    gram = torch.matmul(tensor.T, tensor)
    u = torch.randn(tensor.shape[1], device=tensor.device)
    u = u / u.norm()
    for _ in range(num_iterations):
        v = torch.matmul(gram, u)
        # u has unit norm, so dot(v, u) is the Rayleigh quotient of gram at u.
        spectral_norm = torch.dot(v, u)
        u = v / v.norm()
    return spectral_norm.sqrt()


def _norm_ratio(matrix):
    """Return Frobenius norm divided by the estimated top singular value."""
    # Cast once so both norms are computed in float32.
    matrix = matrix.float()
    return (torch.norm(matrix, p='fro') / max_eigenvalue(matrix)).item()


def stable_rank(param):
    """Compute the ratio of Frobenius norm to (estimated) spectral norm.

    NOTE(review): the stable rank is commonly defined with both norms
    squared; this function keeps the existing un-squared ratio so reported
    values stay comparable. Confirm which definition is intended.

    Args:
        param: ``torch.Tensor`` or ``numpy.ndarray``. It is moved to the
            module-level ``device`` before the computation.

    Returns:
        A float for a 2-D input, a list with one float per slice along dim 0
        for a 3-D input, and NaN for any other rank.
    """
    if isinstance(param, np.ndarray):
        param = torch.from_numpy(param)
    param = param.to(device)
    if param.ndim == 2:
        return _norm_ratio(param)
    if param.ndim == 3:
        # One value per slice; collect them all instead of keeping the last.
        return [_norm_ratio(matrix) for matrix in param]
    return float('nan')


# Single-checkpoint metrics, keyed by the name used in the result JSON.
MEASURING_FUNC = {
    'svd_entropy': svd_entropy,
    'stable_rank': stable_rank
}
a8cc15ab6ee36907bdc7a061cd04359b4b83ebf8..f02c09e611545e7a98efab7df4b43acee4da7a34 100644 --- a/debug/accuracy_tools/msprobe/pytorch/config_checking/config_checking.py +++ b/debug/accuracy_tools/msprobe/pytorch/config_checking/config_checking.py @@ -15,7 +15,7 @@ from msprobe.pytorch.common.log import logger from msprobe.pytorch.config_checking.config_checker import ConfigChecker -from msprobe.pytorch.config_checking.ckpt_compare.compare_weight import compare_checkpoints +from msprobe.pytorch.config_checking.ckpt_compare.compare_weight import compare_checkpoints, measure_checkpoints def pack(config_filepath): @@ -29,7 +29,8 @@ def compare(bench_zip_path, cmp_zip_path, outpath): def _config_checking_parser(parser): parser.add_argument('-pack', '--pack', help='Pack a directory into a zip file') parser.add_argument('-c', '--compare', nargs=2, help='Compare two zip files or ckpt dir') - parser.add_argument('-s', '--ckpt-sim', default=False, action='store_true', + parser.add_argument('-m', '--measure', nargs=1, help='Measure checkpoints using svd entropy') + parser.add_argument('-s', '--ckpt-sim', default=False, action='store_true', help='Calculate the similarity of two ckpt') parser.add_argument('-o', '--output', help='output path, default is current directory') @@ -44,6 +45,10 @@ def _run_config_checking_command(args): else: output_dirpath = args.output if args.output else "./config_check_result" compare(args.compare[0], args.compare[1], output_dirpath) + elif args.measure: + logger.info(f"Measure checkpoints") + output_path = args.output if args.output else "./ckpt_measure_out.json" + measure_checkpoints(args.measure[0], output_path) else: logger.error("The param is not correct, you need to give '-pack' for pack or '-c' for compare.") raise Exception("The param is not correct, you need to give '-pack' for pack or '-c' for compare.")