From 21cf7be49ba9803f426e08e1137faf8dddf36aec Mon Sep 17 00:00:00 2001
From: qianggee
Date: Sun, 24 Nov 2024 03:33:53 +0000
Subject: [PATCH] revert

---
 debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py   | 3 ++-
 debug/accuracy_tools/msprobe/pytorch/monitor/module_metric.py | 2 +-
 .../msprobe/pytorch/monitor/optimizer_collect.py              | 4 ++--
 3 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py b/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py
index 7337d6f47a..555c0b4759 100644
--- a/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py
+++ b/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py
@@ -803,7 +803,8 @@ class TrainerMon:
                                f"maybe something wrong happened. Now clear it.")
                 context.actvgrad.clear()
 
-            get_metrics(self.ops, tbtag_tensor_map, self.eps, self.grad_context.actv)
+            # get_metrics(self.ops, tbtag_tensor_map, self.eps, self.grad_context.actv)
+            self.grad_context.actv.update(get_metrics(self.ops, tbtag_tensor_map, self.eps))
 
             context.micro_step += 1
             if context.micro_step == self.micro_batch_number:
diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/module_metric.py b/debug/accuracy_tools/msprobe/pytorch/monitor/module_metric.py
index 8082d8933c..d40dc59969 100644
--- a/debug/accuracy_tools/msprobe/pytorch/monitor/module_metric.py
+++ b/debug/accuracy_tools/msprobe/pytorch/monitor/module_metric.py
@@ -165,7 +165,7 @@ def write_metrics_base(ops, summary_writer, metric_value, step, prefix=''):
     for op2tensor in metric_value.values():
         tensors.extend(op2tensor.values())
     with torch.no_grad():
-        metric_list = torch.stack(tensors).squeeze().cpu()
+        metric_list = torch.stack(tensors).cpu()
     for tag, metric in zip(tags, metric_list):
         summary_writer.add_scalar(tag, metric, step)
 
diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/optimizer_collect.py b/debug/accuracy_tools/msprobe/pytorch/monitor/optimizer_collect.py
index c13987ebcf..b020049503 100644
--- a/debug/accuracy_tools/msprobe/pytorch/monitor/optimizer_collect.py
+++ b/debug/accuracy_tools/msprobe/pytorch/monitor/optimizer_collect.py
@@ -88,10 +88,10 @@ class OptimizerMon(ABC):
                 partition_id = dist.get_rank()
 
         def get_flatten_grad(self, optimizer, group_idx):
-            if self.is_stage3 or optimizer.cpu_offload:
+            if not self.is_stage3 and optimizer.cpu_offload:
                 return fp32_partitioned_groups_flat[group_idx].grad
             elif fp32_partitioned_groups_flat[group_idx].grad is None:
-                if partition_id == dist.get_world_size() - 1:
+                if partition_id == dist.get_world_size() - 1 and not self.is_stage3:
                     fp32_partitioned_groups_flat_grad = optimizer.flatten_dense_tensors_aligned(
                         optimizer.averaged_gradients[group_idx],
                         int(optimizer.partition_size[group_idx])
-- 
Gitee