diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py b/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py
index 7337d6f47a842e5e744a2cf3a5f467613eb12fe2..555c0b4759876ed1f9b7743331b27da6822ab4f9 100644
--- a/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py
+++ b/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py
@@ -803,7 +803,8 @@ class TrainerMon:
                                f"maybe something wrong happened. Now clear it.")
                 context.actvgrad.clear()
 
-            get_metrics(self.ops, tbtag_tensor_map, self.eps, self.grad_context.actv)
+            # get_metrics(self.ops, tbtag_tensor_map, self.eps, self.grad_context.actv)
+            self.grad_context.actv.update(get_metrics(self.ops, tbtag_tensor_map, self.eps))
 
             context.micro_step += 1
             if context.micro_step == self.micro_batch_number:
diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/module_metric.py b/debug/accuracy_tools/msprobe/pytorch/monitor/module_metric.py
index 8082d8933cfc394ef57e45045264a96a8688f271..d40dc59969c0f9515b42997f229c81197d797514 100644
--- a/debug/accuracy_tools/msprobe/pytorch/monitor/module_metric.py
+++ b/debug/accuracy_tools/msprobe/pytorch/monitor/module_metric.py
@@ -165,7 +165,7 @@ def write_metrics_base(ops, summary_writer, metric_value, step, prefix=''):
     for op2tensor in metric_value.values():
         tensors.extend(op2tensor.values())
     with torch.no_grad():
-        metric_list = torch.stack(tensors).squeeze().cpu()
+        metric_list = torch.stack(tensors).cpu()
     for tag, metric in zip(tags, metric_list):
         summary_writer.add_scalar(tag, metric, step)
 
diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/optimizer_collect.py b/debug/accuracy_tools/msprobe/pytorch/monitor/optimizer_collect.py
index c13987ebcf553f754dbff45ae6016f5339e104ee..b020049503be6f1fd62789b4780fd1cdbf2f322c 100644
--- a/debug/accuracy_tools/msprobe/pytorch/monitor/optimizer_collect.py
+++ b/debug/accuracy_tools/msprobe/pytorch/monitor/optimizer_collect.py
@@ -88,10 +88,10 @@ class OptimizerMon(ABC):
             partition_id = dist.get_rank()
 
         def get_flatten_grad(self, optimizer, group_idx):
-            if self.is_stage3 or optimizer.cpu_offload:
+            if not self.is_stage3 and optimizer.cpu_offload:
                 return fp32_partitioned_groups_flat[group_idx].grad
             elif fp32_partitioned_groups_flat[group_idx].grad is None:
-                if partition_id == dist.get_world_size() - 1:
+                if partition_id == dist.get_world_size() - 1 and not self.is_stage3:
                     fp32_partitioned_groups_flat_grad = optimizer.flatten_dense_tensors_aligned(
                         optimizer.averaged_gradients[group_idx],
                         int(optimizer.partition_size[group_idx])
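
Dropping .squeeze() in write_metrics_base most plausibly guards the case where only a single metric tensor is stacked: squeezing a length-1 result collapses it to a zero-dimensional tensor, and PyTorch refuses to iterate a 0-d tensor, so the zip(tags, metric_list) loop would raise a TypeError. A minimal sketch of that failure mode follows; the tag string and scalar value are made up for illustration and are not taken from the patch.

import torch

# Sketch (assumption): each collected metric is a scalar tensor, and on some steps
# only one metric is present, so torch.stack() produces a length-1 tensor.
tags = ["actv/layer0/min"]          # hypothetical tag, not from the patch
tensors = [torch.tensor(0.5)]       # hypothetical single metric value

stacked = torch.stack(tensors)      # shape (1,)
squeezed = stacked.squeeze()        # shape () -- zero-dimensional

try:
    for tag, metric in zip(tags, squeezed):
        print(tag, metric)
except TypeError as err:
    print(err)                      # "iteration over a 0-d tensor"

# Without .squeeze(), the leading dimension survives and iteration works:
for tag, metric in zip(tags, stacked):
    print(tag, metric.item())       # actv/layer0/min 0.5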