diff --git a/debug/accuracy_tools/kj600/kj600/module_hook.py b/debug/accuracy_tools/kj600/kj600/module_hook.py
index 76c47a7edc710ce5dc9cd103ba99b51b46e8b807..eb0f65f1c6c265fccb175b71ef1104b5d7e46e52 100644
--- a/debug/accuracy_tools/kj600/kj600/module_hook.py
+++ b/debug/accuracy_tools/kj600/kj600/module_hook.py
@@ -48,7 +48,9 @@ class OptimizerContext:
     def __init__(self) -> None:
         self.step = 0
         self.param_gnorm = defaultdict(float)
+        self.param_gsign = defaultdict(int)
         self.param_exp_avg_norm = defaultdict(float)
+        self.param_exp_avg_sign = defaultdict(int)
         self.param_exp_avg_sq_norm = defaultdict(float)
         self.param_effective_rank = defaultdict(float)
         self.param_adam_update = defaultdict()
@@ -69,6 +71,7 @@ class TrainerMon:
         self.config = get_config(config_file_path)
         self.module_rank_list = [int(rank) for rank in self.config.get("module_ranks", "").split(',') if rank.strip()]
         self.ur_distribution = self.config.get('ur_distribution', False)
+        self.mg_direction = self.config.get('mg_direction', False)
         self.optimizer_hooked = False

         output_base_dir = os.getenv('KJ600_OUTPUT_DIR', './kj600_output')
@@ -137,7 +140,7 @@ class TrainerMon:
                 context.verified = True
             if not context.ignore_in:
                 cared_input_grad = input_grad if context.focused_in_col is None else input_grad[context.focused_in_col]
-                cared_input_grad_cal_result = square_sum(cared_input_grad)
+                cared_input_grad_cal_result = square_sum(cared_input_grad) if cared_input_grad is not None else torch.tensor(0.)
             else:
                 cared_input_grad_cal_result = None
             cared_output_grad = output_grad if context.focused_out_col is None else output_grad[context.focused_out_col]
@@ -191,14 +194,26 @@ class TrainerMon:
         # in DDP by default use params_have_main_grad
         def optimizer_pre_step_hook(optimizer, args, kwargs):
             context = self.optimizer_context[optimizer]
+            rank = dist.get_rank() if dist.is_initialized() else None
+
+            context.param_exp_avg_norm, context.param_exp_avg_sign, context.param_exp_avg_sq_norm, context.param_adam_update, context.param_adam_ratio = self.mix_precision_optimizer_mon.fetch_mv(
+                optimizer, self.param2name, self.update_heatmap_visualizer, self.ratio_heatmap_visualizer, self.ur_distribution, self.mg_direction)
+
             for param, name in self.param2name.items():
                 grad_for_norm = param.main_grad if self.params_have_main_grad else param.grad
                 context.param_gnorm[name] = grad_for_norm.detach().norm()
                 if "params_effrank" in self.config and name in self.config["params_effrank"]:
                     context.param_effective_rank[name] = eff_rank(param.detach())
-            context.param_exp_avg_norm, context.param_exp_avg_sq_norm, context.param_adam_update, context.param_adam_ratio = self.mix_precision_optimizer_mon.fetch_mv(
-                optimizer, self.param2name, self.update_heatmap_visualizer, self.ratio_heatmap_visualizer, self.ur_distribution)
+                if self.mg_direction:
+                    if context.step == 0:
+                        self.summary_writer.add_scalar(get_summary_writer_tag_name(name, 'adam_mg_direction', rank), 1, context.step)
+                        continue
+                    g_sign = grad_for_norm.detach().sign()
+                    m_sign = context.param_exp_avg_sign[name]
+                    same_direction_ratio = ((m_sign * g_sign).sum().item()/m_sign.numel() + 1)/2
+                    self.summary_writer.add_scalar(get_summary_writer_tag_name(name, 'adam_mg_direction', rank), same_direction_ratio, context.step)
+
             return

         def optimizer_post_step_hook(optimizer, args, kwargs):
diff --git a/debug/accuracy_tools/kj600/kj600/optimizer_collect.py b/debug/accuracy_tools/kj600/kj600/optimizer_collect.py
index 96288149716b2c06c0982bd06d131a8fcb4a3977..44f478416cc30054d566d2167426095eed941210 100644
--- a/debug/accuracy_tools/kj600/kj600/optimizer_collect.py
+++ b/debug/accuracy_tools/kj600/kj600/optimizer_collect.py
@@ -25,7 +25,7 @@ class MixPrecsionOptimizerMon:

     # parameter tensors we want to monitor and their names are in params2name_dict
     # base_optimizer is pytorch optimizer, wrapped_optimizer is a normal object with base_optimizer
-    def fetch_mv(self, torch_opt, params2name, update_heatmap_visualizer, ratio_heatmap_visualizer, ur_distribution):
+    def fetch_mv(self, torch_opt, params2name, update_heatmap_visualizer, ratio_heatmap_visualizer, ur_distribution, mg_direction):
         mix_prec_opt = MixPrecsionOptimizerMon.wrapped_optimizer

         if not self.fp16_to_fp32_param and mix_prec_opt is not None:
@@ -34,6 +34,7 @@ class MixPrecsionOptimizerMon:
                     self.fp16_to_fp32_param[fp16_param] = fp32_param

         exp_avg_norm_dict = defaultdict(float)
+        exp_avg_sign_dict = defaultdict(int)
         exp_avg_sq_norm_dict = defaultdict(float)
         update_dict = defaultdict()
         ratio_dict = defaultdict()
@@ -49,10 +50,12 @@ class MixPrecsionOptimizerMon:
                 exp_avg_sq_norm = exp_avg_sq.detach().norm()
                 exp_avg_norm_dict[name] = exp_avg_norm
                 exp_avg_sq_norm_dict[name] = exp_avg_sq_norm
+                if mg_direction:
+                    exp_avg_sign_dict[name] = exp_avg.detach().sign()
                 if ur_distribution:
                     update_dict[name] = exp_avg / (torch.sqrt(exp_avg_sq) + torch_opt.defaults['eps'])
                     ratio_dict[name] = exp_avg / torch.sqrt(exp_avg_sq)
                     update_heatmap_visualizer[name].pre_cal(update_dict[name])
                     ratio_heatmap_visualizer[name].pre_cal(ratio_dict[name])
-
-        return exp_avg_norm_dict, exp_avg_sq_norm_dict, update_dict, ratio_dict
+
+        return exp_avg_norm_dict, exp_avg_sign_dict, exp_avg_sq_norm_dict, update_dict, ratio_dict
diff --git a/debug/accuracy_tools/kj600/kj600/unittest/config_1.json b/debug/accuracy_tools/kj600/kj600/unittest/config_1.json
index fc6196fc191c4f5792f772858b6d1e25374edc03..a3b10f731d10b64a8b2df703079f9c56080876eb 100644
--- a/debug/accuracy_tools/kj600/kj600/unittest/config_1.json
+++ b/debug/accuracy_tools/kj600/kj600/unittest/config_1.json
@@ -2,5 +2,7 @@
     "targets": {
         "fc": {"input": "tuple[1]:0", "output": "tensor", "input_grad":"tuple[1]:0", "output_grad":"tuple[1]:0"},
         "relu": {"input": "tuple[1]:0", "output": "tensor", "input_grad":"tuple[1]:0", "output_grad":"tuple[1]:0"}
-    }
+    },
+    "ur_distribution": true,
+    "mg_direction": true
 }
\ No newline at end of file
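Note on the 'adam_mg_direction' scalar logged above: it maps the fraction of elements where the Adam momentum (exp_avg) and the current gradient share the same sign into [0, 1] via (mean(sign(m) * sign(g)) + 1) / 2, and step 0 is logged as 1 because no momentum has accumulated yet. A minimal standalone sketch of that computation, using plain PyTorch tensors as stand-ins for the real optimizer state and TrainerMon context (those stand-ins are assumptions for illustration only):

    import torch

    def same_direction_ratio(exp_avg: torch.Tensor, grad: torch.Tensor) -> float:
        # Fraction of elements whose momentum and gradient signs agree, rescaled to [0, 1]:
        # 1.0 -> every element agrees, 0.5 -> no correlation, 0.0 -> every element disagrees.
        m_sign = exp_avg.detach().sign()
        g_sign = grad.detach().sign()
        return ((m_sign * g_sign).sum().item() / m_sign.numel() + 1) / 2

    # Example: three of four elements agree in sign -> (0.5 + 1) / 2 = 0.75
    m = torch.tensor([0.2, -0.1, 0.3, 0.4])
    g = torch.tensor([0.1, -0.2, 0.5, -0.3])
    print(same_direction_ratio(m, g))  # 0.75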