diff --git a/debug/accuracy_tools/kj600/kj600/module_hook.py b/debug/accuracy_tools/kj600/kj600/module_hook.py
index c9b1c54102039e72bf1352e4aa7248ae76329fcd..92c90051e68434d011a348fce0f8c5a808882ba5 100644
--- a/debug/accuracy_tools/kj600/kj600/module_hook.py
+++ b/debug/accuracy_tools/kj600/kj600/module_hook.py
@@ -84,15 +84,12 @@ class GradContext:
     def __init__(self) -> None:
         self.pre = []
         self.post = []
-        self.grad_acc = None
         self.vpp_stage = 0
         self.call_id = 0

     def reset(self):
         self.pre.clear()
         self.post.clear()
-        self.grad_acc.fill_(0.)
-


 class TrainerMon:
@@ -150,7 +147,9 @@ class TrainerMon:
         anomaly_inform = AnomalyInformFactory.create_informer(**alert_setting["inform"]) if "inform" in alert_setting else None

         self.optimizer_hooked = False
-        self.vpp = False
+        self.weight_hooked = False
+        self.param_registered = False
+
         output_base_dir = os.getenv('KJ600_OUTPUT_DIR', './kj600_output')
         cur_time = datetime.now().strftime('%b%d_%H-%M-%S')
         unique_id = str(uuid.uuid4())[:8]
@@ -183,10 +182,11 @@ class TrainerMon:
         # A HeatmapVisualizer instance is associated with an image
         self.update_heatmap_visualizer = defaultdict(HeatmapVisualizer)
         self.ratio_heatmap_visualizer = defaultdict(HeatmapVisualizer)
-        self.micro_batch_number = 0
-        self.param_name_list = []
+        self.micro_batch_number = 1
+        self.vpp = False
         self.param2name = defaultdict(str)
+        self.call_id_handles = []

         self.mix_precision_optimizer_mon = OptimizerMonFactory.create_optimizer_mon(opt_ty)
         if opt_ty is None:
@@ -244,12 +244,29 @@ class TrainerMon:
         self.hook_optimizer()
         return

-    def monitor_gnorm_with_ad(self, model, grad_acc_steps, process_group=None):
-        self.micro_batch_number = grad_acc_steps
-
-        self._hook_weights(model)
+    def monitor_gnorm_with_ad(self, model, grad_acc_steps):
+        if (dist.is_initialized() and dist.get_rank() not in self.module_rank_list):
+            return
+        self.hook_optimizer()
+        if not isinstance(model, list):
+            model = [model]
+
+        if len(model) > 1:
+            self.vpp = True
+            self._smallest_rank_print('vpp enabled')
+
+        if self.print_struct:
+            for vpp_stage, model_chunk in enumerate(model):
+                vpp_stage = f'{vpp_stage}_' if self.vpp else ''
+                self.module_struct.update({vpp_stage+module_name:{} for module_name, module in model_chunk.named_modules()})
+            return
+        self._register_param_name(model)
+
+        if self.wg_distribution:
+            self._hook_model_for_grad_acc(model)
+

     def build_tbtag_tensor_map(self, module_name, tag, tensor):
         metrics = {}
         rank = dist.get_rank() if dist.is_initialized() else None
@@ -267,7 +284,26 @@ class TrainerMon:
                 continue
             metrics[key] = param_tensor[name]
         return metrics
-
+
+    def generate_grad_metric(self, tag):
+
+        rank = dist.get_rank() if dist.is_initialized() else None
+        for param, name in self.param2name.items():
+            context = self.grad_context[name]
+            grad = param.main_grad if self.params_have_main_grad else param.grad
+            if grad is None:
+                print_warn_log(f"grad is None: {name}, maybe something wrong happened.")
+                continue
+            key = get_summary_writer_tag_name(name, tag, rank)
+            metric_dict = {}
+            for metric_name in self.ops:
+                metric_dict[metric_name] = get_metrics(metric_name, {key: grad}, self.eps)
+
+            if tag == 'post_grad':
+                context.post.append(metric_dict)
+            elif tag == 'pre_grad':
+                context.pre.append(metric_dict)
+
     def generate_cc_metrics(self, cc_name, cc_tensor):
         metrics = defaultdict(dict)
         rank = dist.get_rank() if dist.is_initialized() else None
@@ -301,11 +337,12 @@ class TrainerMon:
     def write_grad_tb(self, step):
         if not self.wg_distribution:
             return
-
+
         for name in self.param2name.values():
             context = self.grad_context[name]
             # Inform the anomaly-generation class of the current parameter's call_id, vpp_stage, etc.
-            self.anomaly_data_factory.set_context(context)
+            if self.anomaly_data_factory:
+                self.anomaly_data_factory.set_context(context)
             for metric_name in self.ops:
                 write_metrics_tensorboard(metric_name, self.summary_writer, context.pre, step)
                 write_metrics_tensorboard(metric_name, self.summary_writer, context.post, step)
@@ -329,28 +366,18 @@ class TrainerMon:
             context.param_exp_avg, context.param_exp_avg_sq, context.param_adam_update, context.param_adam_ratio = self.mix_precision_optimizer_mon.fetch_mv(self, optimizer, self.param2name)

-            rank = dist.get_rank() if dist.is_initialized() else None
-            for param, name in self.param2name.items():
-                if "params_effrank" in self.config and name in self.config["params_effrank"]:
-                    context.param_effective_rank[name] = eff_rank(param.detach())
-                grad = param.main_grad if self.params_have_main_grad else param.grad
-                if grad is None:
-                    print_warn_log(f"grad is None: {name}, maybe something wrong happened.")
-                    continue
-                if self.wg_distribution:
-                    metric_dict = {}
-                    key = get_summary_writer_tag_name(name, 'post_grad', rank)
-                    for metric_name in self.ops:
-                        metric_dict[metric_name] = get_metrics(metric_name, {key: grad}, self.eps)
-                    self.grad_context[name].post.append(metric_dict)
-
-                    metric_dict = {}
-                    key = get_summary_writer_tag_name(name, 'pre_grad', rank)
-                    for metric_name in self.ops:
-                        metric_dict[metric_name] = get_metrics(metric_name, {key: self.grad_context[name].grad_acc}, self.eps)
-                    self.grad_context[name].pre.append(metric_dict)
-
-                if self.mg_direction:
+            if self.wg_distribution:
+                self.generate_grad_metric(tag='post_grad')
+
+            if self.mg_direction:
+                for param, name in self.param2name.items():
+                    if "params_effrank" in self.config and name in self.config["params_effrank"]:
+                        context.param_effective_rank[name] = eff_rank(param.detach())
+                    grad = param.main_grad if self.params_have_main_grad else param.grad
+                    if grad is None:
+                        print_warn_log(f"grad is None: {name}, maybe something wrong happened.")
+                        continue
+
                     if context.step == 0:
                         same_direction_ratio = torch.tensor(1.)
                     else:
@@ -392,19 +419,26 @@ class TrainerMon:
             for param_name, _ in context.param_adam_ratio.items():
                 self.ratio_heatmap_visualizer[param_name].visualize(get_summary_writer_tag_name(param_name, 'adam_ratio', rank), context.step, self.summary_writer)

-            for metric_name in self.ops:
-                if not context.metric_list:
-                    break
-                write_metrics_tensorboard(metric_name, self.summary_writer, context.metric_list, context.step)
+            if context.metric_list:
+                for metric_name in self.ops:
+                    write_metrics_tensorboard(metric_name, self.summary_writer, context.metric_list, context.step)
+                context.metric_list.clear()

             context.step += 1
             self.anomaly_data_writer.write_detected_json(self.summary_writer.get_anomalies())
             self.summary_writer.clear_anomalies()
             self.call_id = 0
+            for handle in self.call_id_handles:
+                handle.remove()
+            return
+
+        if self.optimizer_hooked:
             return

         if not self.module_rank_list or (dist.is_initialized() and dist.get_rank() in self.module_rank_list):
             register_optimizer_step_pre_hook(optimizer_pre_step_hook)
             register_optimizer_step_post_hook(optimizer_post_step_hook)
+
+        self.optimizer_hooked = True
         return

     def _smallest_rank_print(self, msg):
@@ -511,41 +545,48 @@ class TrainerMon:
                     hooked_count += 1
         return hooked_count

-    def _hook_weights(self, model):
-        self.wg_distribution = True
-
-        def param_hook(grad, context):
-            with torch.no_grad():
-                context.grad_acc += grad
-                context.call_id = self.call_id
-                self.call_id += 1
-
-        def register_hooks(model_chunk, vpp_stage=None):
+    def _register_param_name(self, model):
+        if self.param_registered:
+            return
+
+        print_rank_0("> parameter names:")
+        for vpp_stage, model_chunk in enumerate(model):
+            prefix = f'{vpp_stage}_' if self.vpp else ''
             for param_name, param in model_chunk.named_parameters():
-                prefix = "" if not self.vpp else f"vpp{vpp_stage}_"
                 name = prefix + param_name
                 for target in self.config['targets'].keys():
-                    context = self.grad_context[name]
-                    context.vpp_stage = 0 if vpp_stage is None else vpp_stage
-                    if param_name.startswith(target) and param.requires_grad:
+                    if param.requires_grad and (param_name.startswith(target) or name.startswith(target)):
                         self._smallest_rank_print(f'>> monitoring: {name}')
                         self.param2name[param] = name
-                        param.register_hook(partial(param_hook, context=context))
-                        context.grad_acc = torch.zeros_like(param).to(DEVICE)
+                        context = self.grad_context[name]
+                        context.vpp_stage = vpp_stage
+                        break
+        self.param_registered = True

-        model = [model] if not isinstance(model, list) else model
-        if len(model) > 1:
-            self.vpp = True
-            self._smallest_rank_print('vpp enabled')
-
-        if self.print_struct:
-            for vpp_stage, model_chunk in enumerate(model):
-                prefix = "" if not self.vpp else f"vpp{vpp_stage}_"
-                self.module_struct = {
-                    prefix + f"{module_name}": {} for module_name, _ in model_chunk.named_modules()}
+    def _hook_weights(self):
+        if self.weight_hooked:
             return
+
+        def param_hook(grad, context):
+            context.call_id = self.call_id
+            self.call_id += 1
+
+        for param, name in self.param2name.items():
+            context = self.grad_context[name]
+            self.call_id_handles.append(param.register_hook(partial(param_hook, context=context)))
+
+        self.weight_hooked = True
+
+    def _hook_model_for_grad_acc(self, model):
+        def model_backward_hook(module, input_grad, output_grad):
+            model_chunk.micro_step += 1
+            if model_chunk.micro_step == (self.micro_batch_number):
+                model_chunk.micro_step = 0
+                self.generate_grad_metric(tag='pre_grad')

-        for index, model_chunk in enumerate(model):
-            vpp_stage = index if self.vpp else 0
-            register_hooks(model_chunk, vpp_stage=vpp_stage)
-
\ No newline at end of file
+
+        for model_chunk in model:
+            setattr(model_chunk,'micro_step', 0)
+            model_chunk.register_full_backward_hook(model_backward_hook)
+
+        self._hook_weights()
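For readers less familiar with the hook mechanics used above, the following is a minimal, self-contained sketch of the accumulation-counting pattern that _hook_model_for_grad_acc relies on: a full backward hook on the model fires once per micro-batch, a counter on the module tracks how many micro-batches have accumulated, and gradient statistics are taken once the accumulation window closes, before the optimizer step. This uses plain PyTorch only; the toy model, optimizer, printed gradient norm, and micro_batch_number value are illustrative assumptions, not kj600 API.

# A minimal sketch of the accumulation-counting pattern, in plain PyTorch.
# The model, optimizer, and printed statistic are stand-ins for illustration only.
import torch
import torch.nn as nn

micro_batch_number = 4  # gradient-accumulation steps per optimizer step (assumed value)
model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 1))

def model_backward_hook(module, grad_input, grad_output):
    # Called once per micro-batch backward pass through `module`.
    module.micro_step += 1
    if module.micro_step == micro_batch_number:
        module.micro_step = 0
        # Accumulation is complete: param.grad now holds the summed gradient,
        # so "pre_grad"-style statistics can be collected before optimizer.step().
        for name, param in module.named_parameters():
            if param.grad is not None:
                print(name, param.grad.norm().item())

setattr(model, 'micro_step', 0)
model.register_full_backward_hook(model_backward_hook)

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
for _ in range(micro_batch_number):
    x = torch.randn(2, 8, requires_grad=True)
    model(x).sum().backward()  # gradients accumulate across micro-batches
optimizer.step()
optimizer.zero_grad()

In the sketch the counter lives on the module the hook receives, so each hooked module tracks its own accumulation window independently; the setattr(model_chunk, 'micro_step', 0) line in the diff prepares the same per-chunk state.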