diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/data_writers.py b/debug/accuracy_tools/msprobe/pytorch/monitor/data_writers.py
index 6310cc18247c4d4b3dc2cf0940f20e431807f69b..bd6bde7e9f6ede789f520acc2138492e99bac509 100644
--- a/debug/accuracy_tools/msprobe/pytorch/monitor/data_writers.py
+++ b/debug/accuracy_tools/msprobe/pytorch/monitor/data_writers.py
@@ -80,9 +80,6 @@ class BaseWriterWithAD:
 
         xpu_stack = torch.stack(xpu_tensors).cpu() if xpu_tensors else torch.tensor([])
 
-        if xpu_stack.__class__.__name__ == 'DTensor':
-            xpu_stack = xpu_stack.to_local()
-
         # Restore in the original input order
         result = []
         cpu_tensors_idx, xpu_tensors_idx = 0, 0
diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/optimizer_collect.py b/debug/accuracy_tools/msprobe/pytorch/monitor/optimizer_collect.py
index 5d7c5d8219e70bb5df3c7bd15de9ec4d4b356a4e..afd242aa4123b1299a39dce2995567ad9b329428 100644
--- a/debug/accuracy_tools/msprobe/pytorch/monitor/optimizer_collect.py
+++ b/debug/accuracy_tools/msprobe/pytorch/monitor/optimizer_collect.py
@@ -49,6 +49,8 @@ class OptimizerMon(object):
             if self.fp16_to_fp32_param and param not in self.fp16_to_fp32_param:
                 continue
             grad = param.main_grad if monitor.params_have_main_grad else param.grad
+            if grad.__class__.__name__ == 'DTensor':
+                grad = grad.to_local()
             element_in_cur_partition = self.fp16_to_fp32_param.get(param, param).numel()
             if param.numel() != element_in_cur_partition:
                 if first_param:
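
Note: the two hunks move the DTensor handling earlier in the pipeline. The removed hunk checked for a DTensor only after stacking in BaseWriterWithAD; the added hunk converts each gradient to its local shard at collection time in OptimizerMon, before it is counted or stacked. Below is a minimal standalone sketch of the duck-typed pattern both hunks rely on; the helper name to_local_if_dtensor is illustrative and not part of msprobe. Comparing the class name, rather than using isinstance against torch.distributed.tensor.DTensor, avoids importing that module, so the check also works on PyTorch builds where it is unavailable.

    import torch


    def to_local_if_dtensor(t: torch.Tensor) -> torch.Tensor:
        # DTensor.to_local() returns the shard held by the current rank as a
        # plain torch.Tensor; ordinary tensors are passed through unchanged.
        if t.__class__.__name__ == 'DTensor':
            return t.to_local()
        return t


    # Plain tensors pass through untouched, so the helper is safe to call
    # unconditionally on every collected gradient.
    grad = torch.ones(4)
    assert to_local_if_dtensor(grad) is grad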