From 9ce0b1c7a3c2c04abbfc737f4fb5e510b10d86dc Mon Sep 17 00:00:00 2001 From: RanZheng <364167184@qq.com> Date: Tue, 19 Aug 2025 16:48:22 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96fsdp2=20Dtensor=E7=9A=84?= =?UTF-8?q?=E8=90=BD=E7=9B=98=EF=BC=8C=E6=8F=90=E5=89=8D=E8=BD=AClocal?= =?UTF-8?q?=E9=98=B2=E6=AD=A2shape=E4=B8=8D=E5=AF=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit (cherry picked from commit b390e59b9fb73f9dd76c039dbba6b5e530e6cabd) --- debug/accuracy_tools/msprobe/pytorch/monitor/data_writers.py | 3 --- .../msprobe/pytorch/monitor/optimizer_collect.py | 2 ++ 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/data_writers.py b/debug/accuracy_tools/msprobe/pytorch/monitor/data_writers.py index 6310cc1824..bd6bde7e9f 100644 --- a/debug/accuracy_tools/msprobe/pytorch/monitor/data_writers.py +++ b/debug/accuracy_tools/msprobe/pytorch/monitor/data_writers.py @@ -80,9 +80,6 @@ class BaseWriterWithAD: xpu_stack = torch.stack(xpu_tensors).cpu() if xpu_tensors else torch.tensor([]) - if xpu_stack.__class__.__name__ == 'DTensor': - xpu_stack = xpu_stack.to_local() - # 按照输入的顺序恢复 result = [] cpu_tensors_idx, xpu_tensors_idx = 0, 0 diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/optimizer_collect.py b/debug/accuracy_tools/msprobe/pytorch/monitor/optimizer_collect.py index 5d7c5d8219..afd242aa41 100644 --- a/debug/accuracy_tools/msprobe/pytorch/monitor/optimizer_collect.py +++ b/debug/accuracy_tools/msprobe/pytorch/monitor/optimizer_collect.py @@ -49,6 +49,8 @@ class OptimizerMon(object): if self.fp16_to_fp32_param and param not in self.fp16_to_fp32_param: continue grad = param.main_grad if monitor.params_have_main_grad else param.grad + if grad.__class__.__name__ == 'DTensor': + grad = grad.to_local() element_in_cur_partition = self.fp16_to_fp32_param.get(param, param).numel() if param.numel() != element_in_cur_partition: if first_param: -- Gitee