diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/anomaly_analyse.py b/debug/accuracy_tools/msprobe/pytorch/monitor/anomaly_analyse.py
index 2ebbcb9480b8c21692b77ec0388cd621979e4346..2f71f4529755c4784f3aa048f878e7ddc69aab4d 100644
--- a/debug/accuracy_tools/msprobe/pytorch/monitor/anomaly_analyse.py
+++ b/debug/accuracy_tools/msprobe/pytorch/monitor/anomaly_analyse.py
@@ -65,6 +65,9 @@ class AnomalyDataWriter:
         Args:
         anomalies: GradAnomalyData对象列表
         """
+        if not isinstance(anomalies, list) or not anomalies:
+            logger.warning("The anomalies should be a list of GradAnomalyData and not empty.")
+            return
         anomalies_json = self.get_anomaly_dict(anomalies)
         logger.info(f"{MonitorConst.ANOMALY_JSON} is at {self.dump_rank_dir}.")
 
diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py b/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py
index cfe0fb3fc153bf84500c29d642d249a040a1fb20..00f6bbc56dad045e7615100fa59e7ded3748a073 100644
--- a/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py
+++ b/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py
@@ -226,6 +226,8 @@ class TrainerMon:
             self.anomaly_data_factory,
             self.ndigits
         )
+        # Initialize the anomaly-detected status flag
+        self.is_anomaly_detected = False
         # 初始化anomaly detected文件目录
         if self.anomaly_data_factory:
             self.anomaly_data_writer = AnomalyDataWriter(os.path.join(output_base_dir, "anomaly_detected"), rank)
@@ -481,6 +483,10 @@ class TrainerMon:
         self.write_metrics(self.ops, self.summary_writer, self.grad_context.acc_metric, step, 'grad_unreduced')
         self.write_metrics(self.ops, self.summary_writer, self.grad_context.post, step, 'grad_reduced')
 
+    def get_anomaly_status(self):
+        """Return the anomaly flag: set True by the post-step hook, cleared by the pre-step hook."""
+        return self.is_anomaly_detected
+
     def hook_optimizer(self, optimizer=None):
         # in DDP by default use params_have_main_grad
         def optimizer_pre_step_hook(optimizer, args, kwargs):
@@ -536,6 +542,8 @@ class TrainerMon:
                 metric_dict.update(cc.data)
                 cc.reset()
 
+            # clear anomaly detection status
+            self.is_anomaly_detected = False
             if not metric_dict:
                 return
             context.metric_dict = metric_dict
@@ -564,8 +572,13 @@ class TrainerMon:
                 self.write_metrics(self.ops, self.summary_writer, context.metric_dict, context.step, 'other')
                 context.metric_dict.clear()
             context.step += 1
+
+            anomaly_data = self.summary_writer.get_anomalies()
+            if anomaly_data:
+                # set anomaly detected status
+                self.is_anomaly_detected = True
-            if self.anomaly_data_factory:
-                self.anomaly_data_writer.write_detected_json(self.summary_writer.get_anomalies())
+            if self.anomaly_data_factory and anomaly_data:
+                self.anomaly_data_writer.write_detected_json(anomaly_data)
             self.summary_writer.clear_anomalies()
             self.call_id = 0
             return