From 0fa6e0023c0aca0891f7473f2478872de64be771 Mon Sep 17 00:00:00 2001 From: wangqingcai Date: Thu, 28 Nov 2024 11:09:04 +0800 Subject: [PATCH] add monitor anomaly detected flag --- .../msprobe/pytorch/monitor/anomaly_analyse.py | 3 +++ .../msprobe/pytorch/monitor/module_hook.py | 14 +++++++++++++- 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/anomaly_analyse.py b/debug/accuracy_tools/msprobe/pytorch/monitor/anomaly_analyse.py index fc21ded951..6fec808d14 100644 --- a/debug/accuracy_tools/msprobe/pytorch/monitor/anomaly_analyse.py +++ b/debug/accuracy_tools/msprobe/pytorch/monitor/anomaly_analyse.py @@ -65,6 +65,9 @@ class AnomalyDataWriter: Args: anomalies: GradAnomalyData对象列表 """ + if not isinstance(anomalies, list) or anomalies == []: + logger.warning("The anomalies should be a list of GradAnomalyData and not empty.") + return anomalies_json = self.get_anomaly_dict(anomalies) logger.info(f"{MonitorConst.ANOMALY_JSON} is at {self.dump_rank_dir}.") diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py b/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py index cb40b99be0..ab80307e34 100644 --- a/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py +++ b/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py @@ -228,6 +228,8 @@ class TrainerMon: self.anomaly_data_factory, self.ndigits ) + # 初始化anomaly_detect状态 + self.is_anomaly_detected = False # 初始化anomaly detected文件目录 if self.anomaly_data_factory: self.anomaly_data_writer = AnomalyDataWriter(os.path.join(output_base_dir, "anomaly_detected"), rank) @@ -491,6 +493,9 @@ class TrainerMon: self.write_metrics(self.ops, self.summary_writer, self.grad_context.acc_metric, step, 'grad_unreduced') self.write_metrics(self.ops, self.summary_writer, self.grad_context.post, step, 'grad_reduced') + def get_anomaly_status(self): + return self.is_anomaly_detected + def hook_optimizer(self, optimizer=None): # in DDP by default use params_have_main_grad def optimizer_pre_step_hook(optimizer, args, kwargs): @@ -547,6 +552,8 @@ class TrainerMon: metric_dict.update(cc.data) cc.reset() + # clear anomaly detection status + self.is_anomaly_detected = False if not metric_dict: return context.metric_dict = metric_dict @@ -576,8 +583,13 @@ class TrainerMon: self.write_metrics(self.ops, self.summary_writer, context.metric_dict, context.step, 'other') context.metric_dict.clear() context.step += 1 + + anomaly_data = self.summary_writer.get_anomalies() + if anomaly_data: + # set anomaly detected status + self.is_anomaly_detected = True if self.anomaly_data_factory: - self.anomaly_data_writer.write_detected_json(self.summary_writer.get_anomalies()) + self.anomaly_data_writer.write_detected_json(anomaly_data) self.summary_writer.clear_anomalies() self.call_id = 0 return -- Gitee