From 4dc1527531710545889e60ca2a335482b8cbf93a Mon Sep 17 00:00:00 2001
From: qijie
Date: Wed, 8 May 2024 06:44:36 +0000
Subject: [PATCH] support g/m direction compare

---
 .../accuracy_tools/kj600/kj600/module_hook.py | 21 ++++++++++++++++---
 .../kj600/kj600/optimizer_collect.py          |  9 +++++---
 .../kj600/kj600/unittest/config_1.json        |  4 +++-
 3 files changed, 27 insertions(+), 7 deletions(-)

diff --git a/debug/accuracy_tools/kj600/kj600/module_hook.py b/debug/accuracy_tools/kj600/kj600/module_hook.py
index 76c47a7edc..eb0f65f1c6 100644
--- a/debug/accuracy_tools/kj600/kj600/module_hook.py
+++ b/debug/accuracy_tools/kj600/kj600/module_hook.py
@@ -48,7 +48,9 @@ class OptimizerContext:
     def __init__(self) -> None:
         self.step = 0
         self.param_gnorm = defaultdict(float)
+        self.param_gsign = defaultdict(int)
         self.param_exp_avg_norm = defaultdict(float)
+        self.param_exp_avg_sign = defaultdict(int)
         self.param_exp_avg_sq_norm = defaultdict(float)
         self.param_effective_rank = defaultdict(float)
         self.param_adam_update = defaultdict()
@@ -69,6 +71,7 @@ class TrainerMon:
         self.config = get_config(config_file_path)
         self.module_rank_list = [int(rank) for rank in self.config.get("module_ranks", "").split(',') if rank.strip()]
         self.ur_distribution = self.config.get('ur_distribution', False)
+        self.mg_direction = self.config.get('mg_direction', False)
 
         self.optimizer_hooked = False
         output_base_dir = os.getenv('KJ600_OUTPUT_DIR', './kj600_output')
@@ -137,7 +140,7 @@ class TrainerMon:
                 context.verified = True
             if not context.ignore_in:
                 cared_input_grad = input_grad if context.focused_in_col is None else input_grad[context.focused_in_col]
-                cared_input_grad_cal_result = square_sum(cared_input_grad)
+                cared_input_grad_cal_result = square_sum(cared_input_grad) if cared_input_grad is not None else torch.tensor(0.)
             else:
                 cared_input_grad_cal_result = None
             cared_output_grad = output_grad if context.focused_out_col is None else output_grad[context.focused_out_col]
@@ -191,14 +194,26 @@ class TrainerMon:
         # in DDP by default use params_have_main_grad
         def optimizer_pre_step_hook(optimizer, args, kwargs):
             context = self.optimizer_context[optimizer]
+            rank = dist.get_rank() if dist.is_initialized() else None
+
+            context.param_exp_avg_norm, context.param_exp_avg_sign, context.param_exp_avg_sq_norm, context.param_adam_update, context.param_adam_ratio = self.mix_precision_optimizer_mon.fetch_mv(
+                optimizer, self.param2name, self.update_heatmap_visualizer, self.ratio_heatmap_visualizer, self.ur_distribution, self.mg_direction)
+
             for param, name in self.param2name.items():
                 grad_for_norm = param.main_grad if self.params_have_main_grad else param.grad
                 context.param_gnorm[name] = grad_for_norm.detach().norm()
                 if "params_effrank" in self.config and name in self.config["params_effrank"]:
                     context.param_effective_rank[name] = eff_rank(param.detach())
 
-            context.param_exp_avg_norm, context.param_exp_avg_sq_norm, context.param_adam_update, context.param_adam_ratio = self.mix_precision_optimizer_mon.fetch_mv(
-                optimizer, self.param2name, self.update_heatmap_visualizer, self.ratio_heatmap_visualizer, self.ur_distribution)
+                if self.mg_direction:
+                    if context.step == 0:
+                        self.summary_writer.add_scalar(get_summary_writer_tag_name(name, 'adam_mg_direction', rank), 1, context.step)
+                        continue
+                    g_sign = grad_for_norm.detach().sign()
+                    m_sign = context.param_exp_avg_sign[name]
+                    same_direction_ratio = ((m_sign * g_sign).sum().item()/m_sign.numel() + 1)/2
+                    self.summary_writer.add_scalar(get_summary_writer_tag_name(name, 'adam_mg_direction', rank), same_direction_ratio, context.step)
+
             return
 
         def optimizer_post_step_hook(optimizer, args, kwargs):
diff --git a/debug/accuracy_tools/kj600/kj600/optimizer_collect.py b/debug/accuracy_tools/kj600/kj600/optimizer_collect.py
index 9628814971..44f478416c 100644
--- a/debug/accuracy_tools/kj600/kj600/optimizer_collect.py
+++ b/debug/accuracy_tools/kj600/kj600/optimizer_collect.py
@@ -25,7 +25,7 @@ class MixPrecsionOptimizerMon:
 
     # parameter tensors we want to monitor and their names are in params2name_dict
    # base_optimizer is pytorch optimizer, wrapped_optimizer is a normal object with base_optimizer
-    def fetch_mv(self, torch_opt, params2name, update_heatmap_visualizer, ratio_heatmap_visualizer, ur_distribution):
+    def fetch_mv(self, torch_opt, params2name, update_heatmap_visualizer, ratio_heatmap_visualizer, ur_distribution, mg_direction):
         mix_prec_opt = MixPrecsionOptimizerMon.wrapped_optimizer
 
         if not self.fp16_to_fp32_param and mix_prec_opt is not None:
@@ -34,6 +34,7 @@ class MixPrecsionOptimizerMon:
                     self.fp16_to_fp32_param[fp16_param] = fp32_param
 
         exp_avg_norm_dict = defaultdict(float)
+        exp_avg_sign_dict = defaultdict(int)
         exp_avg_sq_norm_dict = defaultdict(float)
         update_dict = defaultdict()
         ratio_dict = defaultdict()
@@ -49,10 +50,12 @@ class MixPrecsionOptimizerMon:
             exp_avg_sq_norm = exp_avg_sq.detach().norm()
             exp_avg_norm_dict[name] = exp_avg_norm
             exp_avg_sq_norm_dict[name] = exp_avg_sq_norm
+            if mg_direction:
+                exp_avg_sign_dict[name] = exp_avg.detach().sign()
             if ur_distribution:
                 update_dict[name] = exp_avg / (torch.sqrt(exp_avg_sq) + torch_opt.defaults['eps'])
                 ratio_dict[name] = exp_avg / torch.sqrt(exp_avg_sq)
                 update_heatmap_visualizer[name].pre_cal(update_dict[name])
                 ratio_heatmap_visualizer[name].pre_cal(ratio_dict[name])
-
-        return exp_avg_norm_dict, exp_avg_sq_norm_dict, update_dict, ratio_dict
+
+        return exp_avg_norm_dict, exp_avg_sign_dict, exp_avg_sq_norm_dict, update_dict, ratio_dict
diff --git a/debug/accuracy_tools/kj600/kj600/unittest/config_1.json b/debug/accuracy_tools/kj600/kj600/unittest/config_1.json
index fc6196fc19..a3b10f731d 100644
--- a/debug/accuracy_tools/kj600/kj600/unittest/config_1.json
+++ b/debug/accuracy_tools/kj600/kj600/unittest/config_1.json
@@ -2,5 +2,7 @@
     "targets": {
         "fc": {"input": "tuple[1]:0", "output": "tensor", "input_grad":"tuple[1]:0", "output_grad":"tuple[1]:0"},
         "relu": {"input": "tuple[1]:0", "output": "tensor", "input_grad":"tuple[1]:0", "output_grad":"tuple[1]:0"}
-    }
+    },
+    "ur_distribution": true,
+    "mg_direction": true
 }
\ No newline at end of file
-- 
Gitee
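
For reference, a minimal standalone sketch of the m/g direction metric that optimizer_pre_step_hook records above. The helper name same_direction_ratio and the example tensors are illustrative only; in the patch the gradient comes from param.main_grad / param.grad and the momentum sign from the optimizer's exp_avg state. The mean element-wise sign agreement between gradient and first moment lies in [-1, 1] and is rescaled to [0, 1], so 1.0 means every element moves in the same direction and 0.5 means no correlation.

import torch

def same_direction_ratio(grad: torch.Tensor, exp_avg: torch.Tensor) -> float:
    # Mean sign agreement between the gradient g and Adam's first moment m,
    # rescaled from [-1, 1] to [0, 1]; mirrors the expression used in
    # optimizer_pre_step_hook. (Illustrative helper, not part of the patch.)
    g_sign = grad.detach().sign()
    m_sign = exp_avg.detach().sign()
    return ((m_sign * g_sign).sum().item() / m_sign.numel() + 1) / 2

# Example: three of the four elements share a sign with the momentum.
g = torch.tensor([0.2, -0.1, 0.3, -0.4])
m = torch.tensor([0.5, -0.2, 0.1, 0.7])
print(same_direction_ratio(g, m))  # 0.75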