From 6dcebd346196e1971b607df5b262e6ffbcc58b32 Mon Sep 17 00:00:00 2001
From: heweidong7 <511650494@qq.com>
Date: Tue, 13 Aug 2024 20:34:42 +0800
Subject: [PATCH] Add an optimizer monitor base class OptimizerMon and a
 MegatronFP32OptimizerMon class to support non-mixed-precision optimizers
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../accuracy_tools/kj600/kj600/module_hook.py |  4 +-
 .../kj600/kj600/optimizer_collect.py          | 61 +++++++++++++++----
 2 files changed, 51 insertions(+), 14 deletions(-)

diff --git a/debug/accuracy_tools/kj600/kj600/module_hook.py b/debug/accuracy_tools/kj600/kj600/module_hook.py
index 74ef684a61..ec15377342 100644
--- a/debug/accuracy_tools/kj600/kj600/module_hook.py
+++ b/debug/accuracy_tools/kj600/kj600/module_hook.py
@@ -7,7 +7,7 @@ import torch
 import torch.distributed as dist
 from torch.optim.optimizer import register_optimizer_step_pre_hook, register_optimizer_step_post_hook
 from kj600.module_spec_verifier import get_config, validate_config_spec
-from kj600.optimizer_collect import MixPrecsionOptimizerMon, print_rank_0, OptimizerMonFactory, MegatronDistributedOptimizerMon
+from kj600.optimizer_collect import OptimizerMon, MixPrecsionOptimizerMon, print_rank_0, OptimizerMonFactory, MegatronDistributedOptimizerMon
 from kj600.features import eff_rank, get_sign_matches
 from kj600.visualizer import HeatmapVisualizer
 from kj600.anomaly_detect import AnomalyScanner, SummaryWriterWithAD
@@ -172,7 +172,7 @@ class TrainerMon:
 
     @staticmethod
     def set_wrapped_optimizer(_wrapped_optimizer):
-        MixPrecsionOptimizerMon.set_wrapped_optimizer(_wrapped_optimizer)
+        OptimizerMon.set_wrapped_optimizer(_wrapped_optimizer)
 
     @staticmethod
     def adhoc_check(target_tensor:torch.tensor, module_name:str, tensor_name:str, rank_list, ops_list):
diff --git a/debug/accuracy_tools/kj600/kj600/optimizer_collect.py b/debug/accuracy_tools/kj600/kj600/optimizer_collect.py
index 285f17ca6d..b274932b39 100644
--- a/debug/accuracy_tools/kj600/kj600/optimizer_collect.py
+++ b/debug/accuracy_tools/kj600/kj600/optimizer_collect.py
@@ -1,3 +1,4 @@
+from abc import ABC, abstractmethod
 from collections import defaultdict
 import torch
 import torch.distributed as dist
@@ -12,20 +13,26 @@ def print_rank_0(message, debug=False, force=False):
         print(message)
 
 
-class MixPrecsionOptimizerMon:
+class OptimizerMon(ABC):
     wrapped_optimizer = None
 
+    @classmethod
+    def set_wrapped_optimizer(cls, wrapped_optimizer):
+        cls.wrapped_optimizer = wrapped_optimizer
+
+    @abstractmethod
+    def fetch_mv(self, monitor, torch_opt, params2name):
+        pass
+
+
+class MixPrecsionOptimizerMon(OptimizerMon):
     def __init__(self) -> None:
         self.fp16_to_fp32_param = {}
 
-    @staticmethod
-    def set_wrapped_optimizer(_wrapped_optimizer):
-        MixPrecsionOptimizerMon.wrapped_optimizer = _wrapped_optimizer
-
     # parameter tensors we want to monitor and their names are in params2name_dict
     # base_optimizer is pytorch optimizer, wrapped_optimizer is a normal object with base_optimizer
     def fetch_mv(self, monitor, torch_opt, params2name):
-        mix_prec_opt = MixPrecsionOptimizerMon.wrapped_optimizer
+        mix_prec_opt = self.wrapped_optimizer
 
         if not self.fp16_to_fp32_param and mix_prec_opt is not None:
             for fp16_group, fp32_group in zip(mix_prec_opt.float16_groups, mix_prec_opt.fp32_from_float16_groups):
@@ -42,7 +49,7 @@ class MixPrecsionOptimizerMon:
         for param, name in params2name.items():
             if param in self.fp16_to_fp32_param:
                 param = self.fp16_to_fp32_param[param]
-            
+
             if param in torch_opt.state:
                 exp_avg = torch_opt.state[param]["exp_avg"]
                 exp_avg_sq = torch_opt.state[param]["exp_avg_sq"]
@@ -61,18 +68,44 @@ class MixPrecsionOptimizerMon:
 
 class MegatronDistributedOptimizerMon(MixPrecsionOptimizerMon):
     def fetch_mv(self, monitor, torch_opt, params2name):
-        mix_prec_opt = MixPrecsionOptimizerMon.wrapped_optimizer
-        if not (hasattr(mix_prec_opt, "model_float16_groups") and hasattr(mix_prec_opt, "shard_fp32_from_float16_groups")):
+        mix_prec_opt = self.wrapped_optimizer
+        if not (hasattr(mix_prec_opt, "model_float16_groups") and hasattr(mix_prec_opt,
+                "shard_fp32_from_float16_groups")):
             raise Exception("megatron distributed optimizer should have model_float16_groups and shard_fp32_from_float16_groups, \
                 if not, please check megatron-lm version")
         if not self.fp16_to_fp32_param and mix_prec_opt is not None:
-            for fp16_group, shard_fp32_group in zip(mix_prec_opt.model_float16_groups, mix_prec_opt.shard_fp32_from_float16_groups):
+            for fp16_group, shard_fp32_group in zip(mix_prec_opt.model_float16_groups,
+                                                    mix_prec_opt.shard_fp32_from_float16_groups):
                 for fp16_param, shard_fp32_param in zip(fp16_group, shard_fp32_group):
                     self.fp16_to_fp32_param[fp16_param] = shard_fp32_param
 
         return self._fetch_mv_in_adam(params2name, torch_opt, monitor)
 
 
+class MegatronFP32OptimizerMon(OptimizerMon):
+    def fetch_mv(self, monitor, torch_opt, params2name):
+        exp_avg_dict = defaultdict(float)
+        exp_avg_sq_dict = defaultdict(float)
+        update_dict = defaultdict()
+        ratio_dict = defaultdict()
+
+        for param, name in params2name.items():
+            if param in torch_opt.state:
+                exp_avg = torch_opt.state[param]["exp_avg"]
+                exp_avg_sq = torch_opt.state[param]["exp_avg_sq"]
+                if monitor.mv_distribution:
+                    exp_avg_dict[name] = exp_avg
+                    exp_avg_sq_dict[name] = exp_avg_sq
+                if monitor.mg_direction:
+                    exp_avg_dict[name] = exp_avg
+                if monitor.ur_distribution:
+                    update_dict[name] = exp_avg / (torch.sqrt(exp_avg_sq) + torch_opt.defaults['eps'])
+                    ratio_dict[name] = exp_avg / torch.sqrt(exp_avg_sq)
+                    monitor.update_heatmap_visualizer[name].pre_cal(update_dict[name])
+                    monitor.ratio_heatmap_visualizer[name].pre_cal(ratio_dict[name])
+        return exp_avg_dict, exp_avg_sq_dict, update_dict, ratio_dict
+
+
 class DummyOptimizerMon(MixPrecsionOptimizerMon):
     def fetch_mv(self, monitor, torch_opt, params2name):
         return None, None, None, None
@@ -80,11 +113,15 @@ class DummyOptimizerMon(MixPrecsionOptimizerMon):
 
 class OptimizerMonFactory:
     @staticmethod
-    def create_optimizer_mon(opt_ty:str):
+    def create_optimizer_mon(opt_ty: str):
         if opt_ty == "Megatron_Float16OptimizerWithFloat16Params":
             return MixPrecsionOptimizerMon()
         if opt_ty == "Megatron_DistributedOptimizer":
             return MegatronDistributedOptimizerMon()
+        if opt_ty == "Megatron_FP32Optimizer":
+            return MegatronFP32OptimizerMon()
         if opt_ty is None or opt_ty == "unknown":
             return DummyOptimizerMon()
-        raise Exception("opt_ty should be Megatron_Float16OptimizerWithFloat16Params or Megatron_DistributedOptimizer or None or unknown")
+        raise Exception(
+            "opt_ty should be Megatron_Float16OptimizerWithFloat16Params or Megatron_DistributedOptimizer or "
+            "Megatron_FP32Optimizer or None or unknown")
-- 
Gitee
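
Usage note (not part of the patch): the sketch below shows how the new factory and the FP32 path could be exercised. Only OptimizerMonFactory.create_optimizer_mon, fetch_mv, and the "Megatron_FP32Optimizer" key come from this change; the TinyMonitorConfig stand-in and the plain torch.optim.Adam setup are illustrative assumptions, since the monitor fields are normally supplied by kj600's TrainerMon.

# Illustrative sketch only: drives MegatronFP32OptimizerMon.fetch_mv with a plain
# torch.optim.Adam optimizer. TinyMonitorConfig is a hypothetical stand-in for the
# monitor object that kj600's TrainerMon normally passes in.
import torch
from kj600.optimizer_collect import OptimizerMonFactory


class TinyMonitorConfig:
    mv_distribution = True     # collect exp_avg / exp_avg_sq per parameter
    mg_direction = False
    ur_distribution = False    # heatmap visualizers are not wired up in this sketch
    update_heatmap_visualizer = {}
    ratio_heatmap_visualizer = {}


model = torch.nn.Linear(4, 2)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

# One training step so Adam populates exp_avg / exp_avg_sq in opt.state.
model(torch.randn(8, 4)).sum().backward()
opt.step()

params2name = {p: n for n, p in model.named_parameters()}

# FP32 path: no wrapped mixed-precision optimizer needs to be registered.
fp32_mon = OptimizerMonFactory.create_optimizer_mon("Megatron_FP32Optimizer")
exp_avg, exp_avg_sq, update, ratio = fp32_mon.fetch_mv(TinyMonitorConfig(), opt, params2name)
print(sorted(exp_avg.keys()))  # parameter names whose first moments were collected

With ur_distribution enabled, fetch_mv would additionally read torch_opt.defaults['eps'] and call pre_cal on a heatmap visualizer per parameter name, so those visualizers would have to be populated first (TrainerMon does this in the real flow).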