From 053a25a8dd9a64e4edf701f178c3bbd0247db06c Mon Sep 17 00:00:00 2001 From: RanZheng <364167184@qq.com> Date: Mon, 28 Jul 2025 17:34:51 +0800 Subject: [PATCH 1/7] =?UTF-8?q?=E3=80=90feature=E3=80=91monitor=E6=94=AF?= =?UTF-8?q?=E6=8C=81fsdp2=E6=A2=AF=E5=BA=A6=E9=87=87=E9=9B=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../msprobe/mindspore/monitor/module_hook.py | 2 + .../msprobe/pytorch/monitor/data_writers.py | 3 ++ .../msprobe/pytorch/monitor/module_hook.py | 41 ++++++++++++++++++- .../pytorch/monitor/optimizer_collect.py | 2 +- 4 files changed, 45 insertions(+), 3 deletions(-) diff --git a/debug/accuracy_tools/msprobe/mindspore/monitor/module_hook.py b/debug/accuracy_tools/msprobe/mindspore/monitor/module_hook.py index 375314bff..be9f2e0e4 100644 --- a/debug/accuracy_tools/msprobe/mindspore/monitor/module_hook.py +++ b/debug/accuracy_tools/msprobe/mindspore/monitor/module_hook.py @@ -483,6 +483,8 @@ class TrainerMon: if (self.print_struct and not all(value == {} for value in self.module_struct.values()) and not self.struct_printed): self._save_module_struct() + if not self.cc_log_only: + raise Exception("exit after first monitor step when print model struct") if is_skip_step(context.step, self.start_step, self.step_interval, self.has_collect_times, self.collect_times): return diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/data_writers.py b/debug/accuracy_tools/msprobe/pytorch/monitor/data_writers.py index bd6bde7e9..6310cc182 100644 --- a/debug/accuracy_tools/msprobe/pytorch/monitor/data_writers.py +++ b/debug/accuracy_tools/msprobe/pytorch/monitor/data_writers.py @@ -80,6 +80,9 @@ class BaseWriterWithAD: xpu_stack = torch.stack(xpu_tensors).cpu() if xpu_tensors else torch.tensor([]) + if xpu_stack.__class__.__name__ == 'DTensor': + xpu_stack = xpu_stack.to_local() + # 按照输入的顺序恢复 result = [] cpu_tensors_idx, xpu_tensors_idx = 0, 0 diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py b/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py index 45011e59f..77a5f84c0 100644 --- a/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py +++ b/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py @@ -15,6 +15,7 @@ import json import os import uuid +import importlib from collections import defaultdict from datetime import datetime from functools import partial @@ -50,6 +51,7 @@ torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0' if not torch_version_above_or_equal_2: raise ValueError("monitor require torch>=2.0") + FORMAT_MAPPING = { MonitorConst.TENSORBOARD: SummaryWriterWithAD, MonitorConst.CSV: CSVWriterWithAD, @@ -160,6 +162,7 @@ class TrainerMon: self.ratio_heatmap_visualizer = defaultdict(HeatmapVisualizer) self.origin_start_grad_sync = None self.fsdp_post_backward_hook = None + self.fsdp2_foreach_reduce = None self.config_timestamp = 0 # 后面有校验时间戳, 首次监控无需为了更新config文件时间戳而去改, 可通过dynamic_on开关直接打开 self.config = load_json(config_file_path) validate_config(self.config) @@ -195,6 +198,7 @@ class TrainerMon: self.tp_group = None self.enable_megatron = False self.fsdp_wrapped_module = False + self.fsdp2_wrapped_module = False self.micro_batch_number = 1 self.optimizer_mon = None self.optimizer_trans = None @@ -209,6 +213,7 @@ class TrainerMon: self.param2name = defaultdict(str) self.name2indices = defaultdict() self.name2param = {} + self.origin2squash = {} self.duplicate_param = {} self.name2tag = {} self.param_name_call_id = {} @@ -735,7 +740,6 @@ class TrainerMon: 
hook(optimizer, args, kwargs) step_final_hook(optimizer, args, kwargs) return out - return wrapper optimizer.__class__.step = patch_step(optimizer.__class__.step, optimizer) @@ -820,6 +824,9 @@ class TrainerMon: elif self.fsdp_post_backward_hook: # fsdp torch.distributed.fsdp._runtime_utils._post_backward_hook = self.fsdp_post_backward_hook logger.info("remove patch_post_backward_hook in fsdp.") + if self.fsdp2_foreach_reduce: # fsdp + torch.distributed.fsdp._fully_shard._fsdp_collectives.foreach_reduce = self.fsdp2_foreach_reduce + logger.info("remove patch_foreach_reduce_hook in fsdp2.") else: # not megatron and not fsdp for handle in self.handles['wgrads']: handle.remove() @@ -904,12 +911,15 @@ class TrainerMon: continue if not self.fsdp_wrapped_module and param_name.startswith("_fsdp_wrapped_module"): self.fsdp_wrapped_module = True + if not self.fsdp2_wrapped_module and param.__class__.__name__ == "DTensor": + self.fsdp2_wrapped_module = True if self._is_target_param(param_name, param, prefix): name = prefix + squash_param_name(param_name, self.squash_name) if name in self.param2name.values(): name = prefix + param_name self.param2name[param] = name self.name2param[name] = param + self.origin2squash[param_name] = name if self.tp_group and not param_is_not_tensor_parallel_duplicate(param, self.tp_group): self.duplicate_param[name] = True @@ -1088,6 +1098,11 @@ class TrainerMon: self._patch_fsdp_post_backward_hook() return + if self.fsdp2_wrapped_module: + # patch fsdp _runtime_utils._post_backward_hook + self._patch_fsdp2_foreach_reduce() + return + if self.monitor_mbs_grad: self._hook_weights() return @@ -1121,7 +1136,6 @@ class TrainerMon: 每个forward阶段,fsdp对AccumulateGrad重复注册hook方法,monitor工具内注册hook无法生效, 因此对_post_backward_hook进行patch,在backward后,reduce_scatter前采集梯度。 """ - def patch_post_backward_hook(_post_backward_hook): def wrapper(state, handle, *unused): grad_dict = {} @@ -1148,6 +1162,29 @@ class TrainerMon: torch.distributed.fsdp._runtime_utils._post_backward_hook = \ patch_post_backward_hook(torch.distributed.fsdp._runtime_utils._post_backward_hook) + def _patch_fsdp2_foreach_reduce(self): + def patch_foreach_reduce(foreach_reduce): + def wrapper(fsdp_params, unsharded_grads, *unused): + grad_dict = {} + for param, grad in zip(fsdp_params, unsharded_grads): + tag = self.name2tag.get(self.origin2squash[param._param_fqn], {}).get(MonitorConst.PRE_GRAD) + if tag is None: + continue + grad_dict[tag] = grad + self.register_param_call_id("foreach_reduce", tag) + get_metrics(self.ops, grad_dict, self.eps, self.grad_context.pre) + out = foreach_reduce(fsdp_params, unsharded_grads, *unused) + return out + + return wrapper + + logger.info("Patch fsdp foreach_reduce, collect pre_grad metrics.") + import torch.distributed.fsdp._fully_shard._fsdp_param_group as _fsdp_param_group + import torch.distributed.fsdp._fully_shard._fsdp_collectives as _fsdp_collectives + self.fsdp_foreach_reduce = _fsdp_collectives.foreach_reduce + _fsdp_collectives.foreach_reduce = patch_foreach_reduce(_fsdp_collectives.foreach_reduce) + importlib.reload(_fsdp_param_group) # 关键操作,不然会因为torch一开始就import foreach_reduce导致patch失效 + def _hook_weights(self): """ 遍历参数的梯度生成函数(grad_acc),并挂载hook,以便在该参数所有梯度计算后,采集通信聚合前梯度数据。 diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/optimizer_collect.py b/debug/accuracy_tools/msprobe/pytorch/monitor/optimizer_collect.py index c2991839e..5d7c5d821 100644 --- a/debug/accuracy_tools/msprobe/pytorch/monitor/optimizer_collect.py +++ 
b/debug/accuracy_tools/msprobe/pytorch/monitor/optimizer_collect.py @@ -53,7 +53,7 @@ class OptimizerMon(object): if param.numel() != element_in_cur_partition: if first_param: grad = grad.flatten()[-element_in_cur_partition:] - else: # supposed to be the last one + else: # supposed to be the last one grad = grad.flatten()[:element_in_cur_partition] first_param = False -- Gitee From 9b8850777cf654ee7115a5614d504e9b684c87be Mon Sep 17 00:00:00 2001 From: RanZheng <364167184@qq.com> Date: Thu, 31 Jul 2025 11:29:40 +0800 Subject: [PATCH 2/7] =?UTF-8?q?=E3=80=90feature=E3=80=91monitor=E6=94=AF?= =?UTF-8?q?=E6=8C=81fsdp2=E6=A2=AF=E5=BA=A6=E9=87=87=E9=9B=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit (cherry picked from commit aa2698ca985b6edb36f8650de064beedadc8f662) --- debug/accuracy_tools/msprobe/docs/19.monitor.md | 15 ++++++++++++++- .../msprobe/pytorch/monitor/module_hook.py | 2 +- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/debug/accuracy_tools/msprobe/docs/19.monitor.md b/debug/accuracy_tools/msprobe/docs/19.monitor.md index 2374ef768..86199b73b 100644 --- a/debug/accuracy_tools/msprobe/docs/19.monitor.md +++ b/debug/accuracy_tools/msprobe/docs/19.monitor.md @@ -403,6 +403,17 @@ export MONITOR_OUTPUT_DIR=/xxx/output_dir }, ``` +注:当配置多条异常告警规则时,优先告警第一条,如以下配置时每一层会优先报AnomalyNan的告警(一般不建议配置多条规则): +```json + "alert": { + "rules": [ + {"rule_name": "AnomalyNan", "args": {"threshold": 1e10}}, + {"rule_name": "AnomalyTurbulence", "args": {"threshold": 0.5}} + ], + "dump": true + }, +``` + 2. 实例化工具时传入流水线并行group ```python monitor = TrainerMon( @@ -430,12 +441,14 @@ monitor = TrainerMon( } ``` +其中call_{xxx}中的xxx为API的执行调用顺序,未后续异常时间排序做准备。 + 3. 异常事件排序 当模型训练过程中出现较多异常数据,需要对异常事件排序。工具提供topk的异常排序能力,按照api的执行顺序进行排序,便于定界首次异常点。异常分析命令示例: ```shell -python3 -m msprobe.pytorch.monitor.anomaly_analyse -d $MONITOR_OUTPUT_DIR/anomaly_detected +python3 -m msprobe.pytorch.monitor.anomaly_processor -d $MONITOR_OUTPUT_DIR/anomaly_detected ``` 异常事件分析结束,将topk事件写入文件`anomaly_detected/anomaly_analyse.json`。异常分析支持以下参数配置: diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py b/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py index 77a5f84c0..880bff36e 100644 --- a/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py +++ b/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py @@ -824,7 +824,7 @@ class TrainerMon: elif self.fsdp_post_backward_hook: # fsdp torch.distributed.fsdp._runtime_utils._post_backward_hook = self.fsdp_post_backward_hook logger.info("remove patch_post_backward_hook in fsdp.") - if self.fsdp2_foreach_reduce: # fsdp + if self.fsdp2_foreach_reduce: # fsdp2 torch.distributed.fsdp._fully_shard._fsdp_collectives.foreach_reduce = self.fsdp2_foreach_reduce logger.info("remove patch_foreach_reduce_hook in fsdp2.") else: # not megatron and not fsdp -- Gitee From 3257bf87ff4a6f1bfb55a945493d61006a38c67d Mon Sep 17 00:00:00 2001 From: RanZheng <364167184@qq.com> Date: Thu, 31 Jul 2025 11:37:47 +0800 Subject: [PATCH 3/7] fix remove (cherry picked from commit 9abe0785c5c3b4df38cb2fc5231e65eb9b532b56) --- .../msprobe/pytorch/monitor/module_hook.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py b/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py index 880bff36e..282042566 100644 --- a/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py +++ b/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py @@ -825,7 
+825,11 @@ class TrainerMon: torch.distributed.fsdp._runtime_utils._post_backward_hook = self.fsdp_post_backward_hook logger.info("remove patch_post_backward_hook in fsdp.") if self.fsdp2_foreach_reduce: # fsdp2 - torch.distributed.fsdp._fully_shard._fsdp_collectives.foreach_reduce = self.fsdp2_foreach_reduce + import torch.distributed.fsdp._fully_shard._fsdp_param_group as _fsdp_param_group + import torch.distributed.fsdp._fully_shard._fsdp_collectives as _fsdp_collectives + _fsdp_collectives.foreach_reduce = self.fsdp2_foreach_reduce + importlib.reload(_fsdp_param_group) + logger.info("remove patch_foreach_reduce_hook in fsdp2.") else: # not megatron and not fsdp for handle in self.handles['wgrads']: @@ -1099,7 +1103,7 @@ class TrainerMon: return if self.fsdp2_wrapped_module: - # patch fsdp _runtime_utils._post_backward_hook + # patch fsdp _fully_shard._fsdp_collectives.foreach_reduce self._patch_fsdp2_foreach_reduce() return @@ -1154,7 +1158,6 @@ class TrainerMon: get_metrics(self.ops, grad_dict, self.eps, self.grad_context.pre) out = _post_backward_hook(state, handle, *unused) return out - return wrapper logger.info("Patch fsdp _post_backward_hook, collect pre_grad metrics.") @@ -1175,13 +1178,12 @@ class TrainerMon: get_metrics(self.ops, grad_dict, self.eps, self.grad_context.pre) out = foreach_reduce(fsdp_params, unsharded_grads, *unused) return out - return wrapper logger.info("Patch fsdp foreach_reduce, collect pre_grad metrics.") import torch.distributed.fsdp._fully_shard._fsdp_param_group as _fsdp_param_group import torch.distributed.fsdp._fully_shard._fsdp_collectives as _fsdp_collectives - self.fsdp_foreach_reduce = _fsdp_collectives.foreach_reduce + self.fsdp2_foreach_reduce = _fsdp_collectives.foreach_reduce _fsdp_collectives.foreach_reduce = patch_foreach_reduce(_fsdp_collectives.foreach_reduce) importlib.reload(_fsdp_param_group) # 关键操作,不然会因为torch一开始就import foreach_reduce导致patch失效 -- Gitee From 6f7591bec68ce208d2a841146384bf1746cb8387 Mon Sep 17 00:00:00 2001 From: RanZheng <364167184@qq.com> Date: Thu, 31 Jul 2025 11:38:26 +0800 Subject: [PATCH 4/7] fix remove (cherry picked from commit ba917dbbc0d5d33432cf0f5a567b7f5602c78a66) --- debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py b/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py index 282042566..9abcc3a18 100644 --- a/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py +++ b/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py @@ -825,10 +825,8 @@ class TrainerMon: torch.distributed.fsdp._runtime_utils._post_backward_hook = self.fsdp_post_backward_hook logger.info("remove patch_post_backward_hook in fsdp.") if self.fsdp2_foreach_reduce: # fsdp2 - import torch.distributed.fsdp._fully_shard._fsdp_param_group as _fsdp_param_group - import torch.distributed.fsdp._fully_shard._fsdp_collectives as _fsdp_collectives - _fsdp_collectives.foreach_reduce = self.fsdp2_foreach_reduce - importlib.reload(_fsdp_param_group) + torch.distributed.fsdp._fully_shard._fsdp_collectives.foreach_reduce = self.fsdp2_foreach_reduce + importlib.reload(torch.distributed.fsdp._fully_shard._fsdp_param_group) logger.info("remove patch_foreach_reduce_hook in fsdp2.") else: # not megatron and not fsdp -- Gitee From 90060ca7dd9647774cc0246ed3f39ef093f372e1 Mon Sep 17 00:00:00 2001 From: RanZheng <364167184@qq.com> Date: Thu, 31 Jul 2025 15:18:34 +0800 Subject: [PATCH 5/7] fix (cherry picked 
from commit 85856a718005625e6a4f3c39b17bcd668b09d820) --- debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py b/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py index 9abcc3a18..3e0721d36 100644 --- a/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py +++ b/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py @@ -827,7 +827,6 @@ class TrainerMon: if self.fsdp2_foreach_reduce: # fsdp2 torch.distributed.fsdp._fully_shard._fsdp_collectives.foreach_reduce = self.fsdp2_foreach_reduce importlib.reload(torch.distributed.fsdp._fully_shard._fsdp_param_group) - logger.info("remove patch_foreach_reduce_hook in fsdp2.") else: # not megatron and not fsdp for handle in self.handles['wgrads']: @@ -1101,7 +1100,7 @@ class TrainerMon: return if self.fsdp2_wrapped_module: - # patch fsdp _fully_shard._fsdp_collectives.foreach_reduce + # patch fsdp2 _fully_shard._fsdp_collectives.foreach_reduce self._patch_fsdp2_foreach_reduce() return @@ -1178,7 +1177,7 @@ class TrainerMon: return out return wrapper - logger.info("Patch fsdp foreach_reduce, collect pre_grad metrics.") + logger.info("Patch fsdp2 foreach_reduce, collect pre_grad metrics.") import torch.distributed.fsdp._fully_shard._fsdp_param_group as _fsdp_param_group import torch.distributed.fsdp._fully_shard._fsdp_collectives as _fsdp_collectives self.fsdp2_foreach_reduce = _fsdp_collectives.foreach_reduce -- Gitee From 51491e22d48f60b204a954426b814152b244d4ff Mon Sep 17 00:00:00 2001 From: RanZheng <364167184@qq.com> Date: Thu, 31 Jul 2025 16:11:22 +0800 Subject: [PATCH 6/7] fix (cherry picked from commit 3be80a11f1473d137ae001fb3e7b4b05d9110d37) --- debug/accuracy_tools/msprobe/docs/19.monitor.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/debug/accuracy_tools/msprobe/docs/19.monitor.md b/debug/accuracy_tools/msprobe/docs/19.monitor.md index 86199b73b..429556cb7 100644 --- a/debug/accuracy_tools/msprobe/docs/19.monitor.md +++ b/debug/accuracy_tools/msprobe/docs/19.monitor.md @@ -441,7 +441,7 @@ monitor = TrainerMon( } ``` -其中call_{xxx}中的xxx为API的执行调用顺序,未后续异常时间排序做准备。 +其中call_{xxx}中的xxx为API的执行调用顺序,未后续异常事件排序做准备。 3. 异常事件排序 -- Gitee From 6c0386375770f893befd733d2bdb201077f08913 Mon Sep 17 00:00:00 2001 From: RanZheng <364167184@qq.com> Date: Thu, 31 Jul 2025 19:33:56 +0800 Subject: [PATCH 7/7] =?UTF-8?q?fix=E6=A3=80=E8=A7=86=E6=84=8F=E8=A7=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit (cherry picked from commit e85db7e776f518fe01108a91d591e618e3cdab4b) --- debug/accuracy_tools/msprobe/docs/19.monitor.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/debug/accuracy_tools/msprobe/docs/19.monitor.md b/debug/accuracy_tools/msprobe/docs/19.monitor.md index 429556cb7..3d4be725d 100644 --- a/debug/accuracy_tools/msprobe/docs/19.monitor.md +++ b/debug/accuracy_tools/msprobe/docs/19.monitor.md @@ -441,14 +441,14 @@ monitor = TrainerMon( } ``` -其中call_{xxx}中的xxx为API的执行调用顺序,未后续异常事件排序做准备。 +其中call_{xxx}中的xxx为API的执行调用顺序,为后续异常事件排序做准备。 3. 异常事件排序 当模型训练过程中出现较多异常数据,需要对异常事件排序。工具提供topk的异常排序能力,按照api的执行顺序进行排序,便于定界首次异常点。异常分析命令示例: ```shell -python3 -m msprobe.pytorch.monitor.anomaly_processor -d $MONITOR_OUTPUT_DIR/anomaly_detected +python3 -m msprobe.core.monitor.anomaly_processor -d $MONITOR_OUTPUT_DIR/anomaly_detected ``` 异常事件分析结束,将topk事件写入文件`anomaly_detected/anomaly_analyse.json`。异常分析支持以下参数配置: -- Gitee
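Note on the FSDP2 gradient-collection approach in this series: FSDP2 (`fully_shard`) keeps parameters as `DTensor`s and funnels gradient aggregation through `torch.distributed.fsdp._fully_shard._fsdp_collectives.foreach_reduce`, so the monitor wraps that function to read the unsharded gradients before reduce-scatter (and converts `DTensor` stacks with `.to_local()` in the data writer). Because `_fsdp_param_group` imports `foreach_reduce` by name when torch is first imported, reassigning `_fsdp_collectives.foreach_reduce` alone is not enough; the patches reload `_fsdp_param_group` after installing the wrapper and again after restoring the original. The sketch below isolates that patch/restore pattern under stated assumptions: it needs a recent PyTorch that ships the private `_fully_shard` package, and the helper names (`install_patch`, `remove_patch`, `_wrapped_foreach_reduce`) are placeholders for illustration, not msprobe APIs.

```python
# Minimal sketch of the foreach_reduce patch/restore pattern this series uses
# for FSDP2 gradient collection. Assumes a PyTorch build that exposes the
# private torch.distributed.fsdp._fully_shard package (FSDP2); the helper
# names here are illustrative placeholders, not part of msprobe.
import importlib

import torch.distributed.fsdp._fully_shard._fsdp_collectives as _fsdp_collectives
import torch.distributed.fsdp._fully_shard._fsdp_param_group as _fsdp_param_group

# Keep a reference to the original so the restore path stays symmetric.
_original_foreach_reduce = _fsdp_collectives.foreach_reduce


def _wrapped_foreach_reduce(fsdp_params, unsharded_grads, *args, **kwargs):
    # The gradients seen here are the full (unsharded) per-parameter grads,
    # before FSDP2 reduce-scatters them; the series computes its metrics at
    # this point, keyed by the parameter FQN (fsdp_param._param_fqn).
    grad_norms = {}
    for fsdp_param, grad in zip(fsdp_params, unsharded_grads):
        name = getattr(fsdp_param, "_param_fqn", repr(fsdp_param))
        grad_norms[name] = grad.norm()
    # ... hand grad_norms to whatever collector/writer you use ...
    return _original_foreach_reduce(fsdp_params, unsharded_grads, *args, **kwargs)


def install_patch():
    _fsdp_collectives.foreach_reduce = _wrapped_foreach_reduce
    # _fsdp_param_group imported foreach_reduce by name at torch import time,
    # so it still holds the original function; reloading it is what makes the
    # wrapper actually get called (the "critical step" noted in the patch).
    importlib.reload(_fsdp_param_group)


def remove_patch():
    _fsdp_collectives.foreach_reduce = _original_foreach_reduce
    importlib.reload(_fsdp_param_group)
```

Saving the original function once and restoring it by the same name is what keeps install/remove symmetric; the series applies the same idea by storing the original on the monitor as `self.fsdp2_foreach_reduce` (a follow-up commit fixes an earlier `self.fsdp_foreach_reduce` typo so the restore branch can find it) and reloading `_fsdp_param_group` on both paths.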