diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/anomaly_analyse.py b/debug/accuracy_tools/msprobe/pytorch/monitor/anomaly_analyse.py index fc21ded9515ed1c792c72d3de423e8dcf8e56e7c..6fec808d14072b46b15fbdba8541727413d6a530 100644 --- a/debug/accuracy_tools/msprobe/pytorch/monitor/anomaly_analyse.py +++ b/debug/accuracy_tools/msprobe/pytorch/monitor/anomaly_analyse.py @@ -65,6 +65,9 @@ class AnomalyDataWriter: Args: anomalies: GradAnomalyData对象列表 """ + if not isinstance(anomalies, list) or anomalies == []: + logger.warning("The anomalies should be a list of GradAnomalyData and not empty.") + return anomalies_json = self.get_anomaly_dict(anomalies) logger.info(f"{MonitorConst.ANOMALY_JSON} is at {self.dump_rank_dir}.") diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py b/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py index cb40b99be027a96c77630e47d7384af4b8045c1b..ab80307e344afdad9cf06b6a7822ace654be9cff 100644 --- a/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py +++ b/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py @@ -228,6 +228,8 @@ class TrainerMon: self.anomaly_data_factory, self.ndigits ) + # 初始化anomaly_detect状态 + self.is_anomaly_detected = False # 初始化anomaly detected文件目录 if self.anomaly_data_factory: self.anomaly_data_writer = AnomalyDataWriter(os.path.join(output_base_dir, "anomaly_detected"), rank) @@ -491,6 +493,9 @@ class TrainerMon: self.write_metrics(self.ops, self.summary_writer, self.grad_context.acc_metric, step, 'grad_unreduced') self.write_metrics(self.ops, self.summary_writer, self.grad_context.post, step, 'grad_reduced') + def get_anomaly_status(self): + return self.is_anomaly_detected + def hook_optimizer(self, optimizer=None): # in DDP by default use params_have_main_grad def optimizer_pre_step_hook(optimizer, args, kwargs): @@ -547,6 +552,8 @@ class TrainerMon: metric_dict.update(cc.data) cc.reset() + # clear anomaly detection status + self.is_anomaly_detected = False if not metric_dict: return context.metric_dict = metric_dict @@ -576,8 +583,13 @@ class TrainerMon: self.write_metrics(self.ops, self.summary_writer, context.metric_dict, context.step, 'other') context.metric_dict.clear() context.step += 1 + + anomaly_data = self.summary_writer.get_anomalies() + if anomaly_data: + # set anomaly detected status + self.is_anomaly_detected = True if self.anomaly_data_factory: - self.anomaly_data_writer.write_detected_json(self.summary_writer.get_anomalies()) + self.anomaly_data_writer.write_detected_json(anomaly_data) self.summary_writer.clear_anomalies() self.call_id = 0 return