From b3c5877626b22e6eb22732d627f2ec333f4f640e Mon Sep 17 00:00:00 2001
From: pengxiaopeng
Date: Sun, 28 Apr 2024 15:22:04 +0800
Subject: [PATCH] support adam update and ratio heatmap

---
 debug/accuracy_tools/kj600/README.md          |  6 +--
 debug/accuracy_tools/kj600/kj600/features.py  | 17 +++++---
 .../accuracy_tools/kj600/kj600/module_hook.py | 17 ++++++--
 .../kj600/kj600/optimizer_collect.py          | 18 ++++++--
 .../accuracy_tools/kj600/kj600/visualizer.py  | 41 +++++++++++++++++++
 5 files changed, 85 insertions(+), 14 deletions(-)
 create mode 100644 debug/accuracy_tools/kj600/kj600/visualizer.py

diff --git a/debug/accuracy_tools/kj600/README.md b/debug/accuracy_tools/kj600/README.md
index 701b5abb6..d03b9606e 100644
--- a/debug/accuracy_tools/kj600/README.md
+++ b/debug/accuracy_tools/kj600/README.md
@@ -41,8 +41,8 @@ pip install -e .
     "targets": {
       "language_model.encoder.layers.0": {"input": "tuple[2]:0", "output": "tensor", "input_grad":"tuple[2]:0", "output_grad":"tuple[1]:0"}
     },
-    "module_ranks":"1,2,3,4",
-    "optimizer_ranks":"1,2,3,4"
+    "module_ranks": "1,2,3,4",
+    "ur_distribution": true
 }
 ```
 
@@ -62,7 +62,7 @@ pip install -e .
 
 "module_ranks": Optional field. In distributed training it controls on which ranks module monitoring is enabled. If omitted, monitoring is enabled on all ranks by default.
 
-"optimizer_ranks": Optional field. In distributed training it controls on which ranks optimizer monitoring is enabled. If omitted, monitoring is enabled on all ranks by default.
+"ur_distribution": Optional field. If true, the value distributions of the Adam optimizer's update and ratio are collected and rendered as heatmaps. Defaults to false.
 
 For reference, the following lists the formats of the forward inputs/outputs and of the backward input-tensor and output-tensor gradients for modules commonly found in transformer-architecture models:
 
diff --git a/debug/accuracy_tools/kj600/kj600/features.py b/debug/accuracy_tools/kj600/kj600/features.py
index 7e726b522..b4fc8f308 100644
--- a/debug/accuracy_tools/kj600/kj600/features.py
+++ b/debug/accuracy_tools/kj600/kj600/features.py
@@ -2,12 +2,12 @@ import torch
 from torch.autograd.functional import jacobian
 
 
-@torch.no_grad
+@torch.no_grad()
 def square_sum(x: torch.tensor):
     return (x * x).sum()
 
 
-@torch.no_grad
+@torch.no_grad()
 def eff_rank(param: torch.tensor, threshold=1e-10):
     U, S, Vh = torch.linalg.svd(param.float())
     rank = torch.sum(S > threshold)
@@ -15,14 +15,14 @@ def eff_rank(param: torch.tensor, threshold=1e-10):
 
 
 # modular neural tangent kernel
-@torch.no_grad
+@torch.no_grad()
 def mNTK(module: torch.nn.Module, x: torch.tensor):
     J_theta_l = jacobian(module, x)
     mntk = torch.matmul(J_theta_l, J_theta_l.t())
     return mntk
 
 
-@torch.no_grad
+@torch.no_grad()
 def power_iteration(A, num_iterations):
     b = torch.randn(A.size(1), 1)
     for _ in range(num_iterations):
@@ -33,7 +33,7 @@ def power_iteration(A, num_iterations):
     return eigval
 
 
-@torch.no_grad
+@torch.no_grad()
 def lambda_max_subsample(module: torch.nn.Module, x: torch.tensor, num_iterations=100, subsample_size=None):
     mntk = mNTK(module, x)
     if subsample_size is None:
@@ -43,3 +43,10 @@ def lambda_max_subsample(module: torch.nn.Module, x: torch.tensor, num_iterations=100, subsample_size=None):
     subsampled = subsampled[:, idx]
     eigval = power_iteration(subsampled, num_iterations)
     return eigval
+
+
+@torch.no_grad()
+def cal_histc(tensor_cal, bins_total, min_val, max_val):
+    return torch.histc(tensor_cal, bins=bins_total, min=min_val, max=max_val)
+
+
diff --git a/debug/accuracy_tools/kj600/kj600/module_hook.py b/debug/accuracy_tools/kj600/kj600/module_hook.py
index 570799ff7..76c47a7ed 100644
--- a/debug/accuracy_tools/kj600/kj600/module_hook.py
+++ b/debug/accuracy_tools/kj600/kj600/module_hook.py
@@ -12,6 +12,7 @@ from kj600.features import square_sum
 from kj600.module_spec_verifier import get_config, validate_config_spec
 from kj600.optimizer_collect import MixPrecsionOptimizerMon, print_rank_0
 from kj600.features import eff_rank
+from kj600.visualizer import HeatmapVisualizer
 
 
 def get_summary_writer_tag_name(module_or_param_name:str, tag:str, rank):
@@ -50,6 +51,8 @@ class OptimizerContext:
         self.param_exp_avg_norm = defaultdict(float)
         self.param_exp_avg_sq_norm = defaultdict(float)
         self.param_effective_rank = defaultdict(float)
+        self.param_adam_update = defaultdict()
+        self.param_adam_ratio = defaultdict()
 
 
 class TrainerMon:
@@ -65,6 +68,7 @@ class TrainerMon:
         self.params_have_main_grad = True
         self.config = get_config(config_file_path)
         self.module_rank_list = [int(rank) for rank in self.config.get("module_ranks", "").split(',') if rank.strip()]
+        self.ur_distribution = self.config.get('ur_distribution', False)
         self.optimizer_hooked = False
 
         output_base_dir = os.getenv('KJ600_OUTPUT_DIR', './kj600_output')
@@ -75,6 +79,9 @@ class TrainerMon:
             self.summary_writer = SummaryWriter(os.path.join(output_base_dir, f"{cur_time}-rank{dist.get_rank()}-{unique_id}"))
         else:
             self.summary_writer = SummaryWriter(os.path.join(output_base_dir, f"{cur_time}-{unique_id}"))
+        # Each HeatmapVisualizer instance corresponds to one heatmap image (one per monitored parameter)
+        self.update_heatmap_visualizer = defaultdict(HeatmapVisualizer)
+        self.ratio_heatmap_visualizer = defaultdict(HeatmapVisualizer)
         self.micro_batch_number = 0
         self.param_name_list = []
 
@@ -190,8 +197,8 @@ class TrainerMon:
                 if "params_effrank" in self.config and name in self.config["params_effrank"]:
                     context.param_effective_rank[name] = eff_rank(param.detach())
 
-            context.param_exp_avg_norm, context.param_exp_avg_sq_norm = self.mix_precision_optimizer_mon.fetch_mv(
-                optimizer, self.param2name)
+            context.param_exp_avg_norm, context.param_exp_avg_sq_norm, context.param_adam_update, context.param_adam_ratio = self.mix_precision_optimizer_mon.fetch_mv(
+                optimizer, self.param2name, self.update_heatmap_visualizer, self.ratio_heatmap_visualizer, self.ur_distribution)
             return
 
         def optimizer_post_step_hook(optimizer, args, kwargs):
@@ -224,7 +231,11 @@ class TrainerMon:
             self.summary_writer.add_scalar(get_summary_writer_tag_name(param_name, 'exp_avg_norm', rank), exp_avg_norm.item(), context.step)
         for param_name, exp_avg_sq_norm in context.param_exp_avg_sq_norm.items():
             self.summary_writer.add_scalar(get_summary_writer_tag_name(param_name, 'exp_avg_sq_norm', rank), exp_avg_sq_norm.item(), context.step)
-
+        if self.ur_distribution:
+            for param_name, _ in context.param_adam_update.items():
+                self.update_heatmap_visualizer[param_name].visualize(get_summary_writer_tag_name(param_name, 'adam_update', rank), context.step, self.summary_writer)
+            for param_name, _ in context.param_adam_ratio.items():
+                self.ratio_heatmap_visualizer[param_name].visualize(get_summary_writer_tag_name(param_name, 'adam_ratio', rank), context.step, self.summary_writer)
         context.step += 1
         return
 
diff --git a/debug/accuracy_tools/kj600/kj600/optimizer_collect.py b/debug/accuracy_tools/kj600/kj600/optimizer_collect.py
index d0991718c..962881497 100644
--- a/debug/accuracy_tools/kj600/kj600/optimizer_collect.py
+++ b/debug/accuracy_tools/kj600/kj600/optimizer_collect.py
@@ -1,5 +1,8 @@
 from collections import defaultdict
+import torch
 import torch.distributed as dist
+from kj600.visualizer import HeatmapVisualizer
+
 
 
 def print_rank_0(message, debug=False, force=False):
@@ -19,8 +22,10 @@ class MixPrecsionOptimizerMon:
 
     def __init__(self) -> None:
         self.fp16_to_fp32_param = {}
-
-    def fetch_mv(self, torch_opt, params2name):
+
+    # params2name maps each parameter tensor we want to monitor to its name
+    # base_optimizer is the underlying PyTorch optimizer; wrapped_optimizer is a plain object that holds base_optimizer
+    def fetch_mv(self, torch_opt, params2name, update_heatmap_visualizer, ratio_heatmap_visualizer, ur_distribution):
         mix_prec_opt = MixPrecsionOptimizerMon.wrapped_optimizer
 
         if not self.fp16_to_fp32_param and mix_prec_opt is not None:
@@ -30,6 +35,8 @@ class MixPrecsionOptimizerMon:
 
         exp_avg_norm_dict = defaultdict(float)
         exp_avg_sq_norm_dict = defaultdict(float)
+        update_dict = defaultdict()
+        ratio_dict = defaultdict()
 
         for param, name in params2name.items():
             if param in self.fp16_to_fp32_param:
@@ -42,5 +49,10 @@ class MixPrecsionOptimizerMon:
                 exp_avg_sq_norm = exp_avg_sq.detach().norm()
                 exp_avg_norm_dict[name] = exp_avg_norm
                 exp_avg_sq_norm_dict[name] = exp_avg_sq_norm
+                if ur_distribution:
+                    update_dict[name] = exp_avg / (torch.sqrt(exp_avg_sq) + torch_opt.defaults['eps'])
+                    ratio_dict[name] = exp_avg / torch.sqrt(exp_avg_sq)
+                    update_heatmap_visualizer[name].pre_cal(update_dict[name])
+                    ratio_heatmap_visualizer[name].pre_cal(ratio_dict[name])
 
-        return exp_avg_norm_dict, exp_avg_sq_norm_dict
\ No newline at end of file
+        return exp_avg_norm_dict, exp_avg_sq_norm_dict, update_dict, ratio_dict
diff --git a/debug/accuracy_tools/kj600/kj600/visualizer.py b/debug/accuracy_tools/kj600/kj600/visualizer.py
new file mode 100644
index 000000000..e1929bfa3
--- /dev/null
+++ b/debug/accuracy_tools/kj600/kj600/visualizer.py
@@ -0,0 +1,41 @@
+import torch
+import numpy as np
+import matplotlib.pyplot as plt
+from kj600.features import cal_histc
+
+
+class HeatmapVisualizer:
+    def __init__(self) -> None:
+        self.histogram_bins_num = 30
+        self.min_val = -1
+        self.max_val = 1
+        self.histogram_edges = None
+        self.histogram_sum_data_np = None  # matrix shape is [bins_num * total_step]
+        self.cur_step_histogram_data = None
+        self.histogram_edges = torch.linspace(self.min_val, self.max_val, self.histogram_bins_num)
+
+    def pre_cal(self, tensor):
+        self.cur_step_histogram_data = cal_histc(tensor_cal=tensor, bins_total=self.histogram_bins_num, min_val=self.min_val, max_val=self.max_val)
+
+    def visualize(self, tag_name:str, step, summary_writer):
+        if self.histogram_sum_data_np is None or self.histogram_sum_data_np.size == 0:
+            self.histogram_sum_data_np = np.expand_dims(self.cur_step_histogram_data.cpu(), 0).T
+        else:
+            # append the new histogram along axis 1 because the first column was transposed above
+            # matrix shape is [bins_num * total_step]
+            self.histogram_sum_data_np = np.concatenate((self.histogram_sum_data_np, np.expand_dims(self.cur_step_histogram_data.cpu(), 1)), axis=1)
+
+        fig, ax = plt.subplots()
+        cax = ax.matshow(self.histogram_sum_data_np, cmap='hot', aspect='auto')
+        fig.colorbar(cax)
+
+        plt.yticks(ticks=range(self.histogram_bins_num), labels=[f'{self.histogram_edges[i]:.2f}' for i in range(self.histogram_bins_num)])
+        ax.set_xlabel('Step')
+        ax.set_ylabel('Value Range')
+        plt.title(f'Total Step: {step}')
+
+        # Convert matplotlib figure to an image format suitable for TensorBoard
+        fig.canvas.draw()
+        image = torch.from_numpy(np.array(fig.canvas.renderer.buffer_rgba()))
+        plt.close(fig)
+        summary_writer.add_image(tag_name, image.permute(2, 0, 1), global_step=step, dataformats='CHW')
-- 
Gitee
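
For reference, a minimal usage sketch of the new HeatmapVisualizer (illustrative only, not part of the patch; the output directory, tag name, and the random data standing in for the Adam update/ratio tensors are assumptions):

    import torch
    from torch.utils.tensorboard import SummaryWriter
    from kj600.visualizer import HeatmapVisualizer

    writer = SummaryWriter("./kj600_output/demo")   # hypothetical output directory
    visualizer = HeatmapVisualizer()
    for step in range(3):
        values = torch.randn(1000)                  # stand-in for exp_avg / (sqrt(exp_avg_sq) + eps)
        visualizer.pre_cal(values)                  # bin this step's values into 30 buckets over [-1, 1]
        visualizer.visualize("demo/adam_update", step, writer)  # append a column and log the heatmap image
    writer.close()

Each visualize() call appends the latest histogram as a new column, so the logged heatmap grows along the step axis as training proceeds.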