diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py b/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py
index 0158b3713758a10ff7470793198d823c4eed773b..457756321eee7fffbec49aefff5ef78e96fd73e7 100644
--- a/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py
+++ b/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py
@@ -159,7 +159,6 @@ class TrainerMon:
         self.params_have_main_grad = params_have_main_grad
         self.update_heatmap_visualizer = defaultdict(HeatmapVisualizer)
         self.ratio_heatmap_visualizer = defaultdict(HeatmapVisualizer)
-        self.origin_start_grad_sync = None
         self.fsdp_post_backward_hook = None
         self.config_timestamp = 0  # 后面有校验时间戳, 首次监控无需为了更新config文件时间戳而去改, 可通过dynamic_on开关直接打开
         self.config = load_json(config_file_path)
@@ -195,6 +194,7 @@ class TrainerMon:
         self.dp_group = None
         self.tp_group = None
         self.enable_megatron = False
+        self.enable_deepspeed = False
         self.fsdp_wrapped_module = False
         self.micro_batch_number = 1
         self.optimizer_mon = None
@@ -804,20 +804,9 @@ class TrainerMon:
             bwd_context.reset()
         self.grad_context.reset()  # 权重梯度和激活值梯度都在这
 
-        if self.origin_start_grad_sync:  # megatron
-            try:
-                from megatron.core.distributed.param_and_grad_buffer import Bucket
-                Bucket.start_grad_sync = self.origin_start_grad_sync
-                logger.info("remove Bucket start_grad_sync")
-            except ImportError:
-                pass
-            try:
-                from megatron.core.distributed.param_and_grad_buffer import _ParamAndGradBucketGroup
-                _ParamAndGradBucketGroup.start_grad_sync = self.origin_start_grad_sync
-                logger.info("remove _ParamAndGradBucketGroup start_grad_sync")
-            except ImportError:
-                pass
-        elif self.fsdp_post_backward_hook:  # fsdp
+
+        self.optimizer_mon.restore_grad_sync(self)
+        if self.fsdp_post_backward_hook:  # fsdp
             torch.distributed.fsdp._runtime_utils._post_backward_hook = self.fsdp_post_backward_hook
             logger.info("remove patch_post_backward_hook in fsdp.")
         else:  # not megatron and not fsdp
@@ -896,7 +885,6 @@ class TrainerMon:
         squash_name = prefix + squash_param_name(param_name, self.squash_name)
         for target in self.config['targets'].keys():
             if param_name.startswith(target) or squash_name.startswith(target) or name.startswith(target):
-                setattr(param, "zero_out_wgrad", True)
                 return True
 
         return False
@@ -1059,31 +1047,6 @@ class TrainerMon:
         return hooked_count
 
     def _patch_grad_sync(self):
-        def patch_sync(sync_grad_func):
-            def wrapper(bucket):
-                grad_dict = {}
-                # Megatron between core_r0.6.0 and core_r0.8.0, this bucket is Bucket.
-                # When megatron is core_r0.9.0, this bucket is _ParamAndGradBucketGroup.
-                # In megatron version core_r0.9.0, func start_grad_sync from Bucket moved to _ParamAndGradBucketGroup.
-                bucket_params_id_list = [id(params) for params in bucket.params]
-                for param, name in self.param2name.items():
-                    if id(param) not in bucket_params_id_list:
-                        continue
-                    grad = param.main_grad if self.params_have_main_grad else param.grad
-                    if grad is None:
-                        logger.warning(f"grad is None: {name}, maybe something wrong happened.")
-                        continue
-                    tag = self.name2tag.get(name, {}).get(MonitorConst.PRE_GRAD)
-                    if tag is None:
-                        continue
-                    grad_dict[tag] = grad
-                    self.register_param_call_id("sync_grad_func", tag)
-                get_metrics(self.ops, grad_dict, self.eps, self.grad_context.pre)
-                out = sync_grad_func(bucket)
-                return out
-
-            return wrapper
-
         if not self.wg_distribution:
             return
         if self.fsdp_wrapped_module:
@@ -1094,24 +1057,10 @@ class TrainerMon:
         if self.monitor_mbs_grad:
             self._hook_weights()
             return
-        try:
-            from megatron.core.distributed.param_and_grad_buffer import Bucket
-            self.origin_start_grad_sync = Bucket.start_grad_sync
-            Bucket.start_grad_sync = patch_sync(Bucket.start_grad_sync)
-            self.enable_megatron = True
-            logger.info("megatron version is >= core_r0.6.0 <= core_r0.8.0")
-        except ImportError:
-            self.enable_megatron = False
+
+        self.optimizer_mon.patch_grad_sync(self)
 
-        try:
-            from megatron.core.distributed.param_and_grad_buffer import _ParamAndGradBucketGroup
-            self.origin_start_grad_sync = _ParamAndGradBucketGroup.start_grad_sync
-            _ParamAndGradBucketGroup.start_grad_sync = patch_sync(_ParamAndGradBucketGroup.start_grad_sync)
-            self.enable_megatron = True
-            logger.info("megatron version is > core_r0.8.0 <= core_r0.9.0")
-        except ImportError:
-            self.enable_megatron = False | self.enable_megatron
-        if self.enable_megatron:
+        if self.enable_megatron or self.enable_deepspeed:
             return
 
         # default hook weights
diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/optimizer_collect.py b/debug/accuracy_tools/msprobe/pytorch/monitor/optimizer_collect.py
index 8a63eaef9c348663c0bc1084b6415050dc90935e..0678162a1137f1ade810194cb3a330c145061fe3 100644
--- a/debug/accuracy_tools/msprobe/pytorch/monitor/optimizer_collect.py
+++ b/debug/accuracy_tools/msprobe/pytorch/monitor/optimizer_collect.py
@@ -18,6 +19,7 @@
 import torch
 
 from msprobe.pytorch.common.log import logger
 from msprobe.core.monitor.utils import MVResult
+from msprobe.pytorch.monitor.module_metric import get_metrics
 from msprobe.core.common.const import MonitorConst
 
@@ -26,6 +27,8 @@ class OptimizerMon(object):
         self.fp16_to_fp32_param = {}
         self.torch_opt = torch_opt
         self.state = {}
+        self.origin_funcs = []
+        self.bucket_class = None
 
     def narrow_from_flatten(self, param, flatten_state):
         return flatten_state
@@ -120,6 +123,59 @@ class OptimizerMon(object):
                 monitor.ratio_heatmap_visualizer[name].pre_cal(ratio_dict[name])
         return MVResult(exp_avg=exp_avg_dict, exp_avg_sq=exp_avg_sq_dict, update=update_dict, ratio=ratio_dict)
 
+    def patch_grad_sync(self, monitor):
+        def patch_sync(sync_grad_func):
+            def wrapper(bucket):
+                grad_dict = {}
+                # In Megatron core_r0.6.0 through core_r0.8.0, this bucket is a Bucket.
+                # From core_r0.9.0 on, this bucket is a _ParamAndGradBucketGroup.
+                # In core_r0.9.0, start_grad_sync moved from Bucket to _ParamAndGradBucketGroup.
+                bucket_params_id_list = [id(params) for params in bucket.params]
+                for param, name in monitor.param2name.items():
+                    if id(param) not in bucket_params_id_list:
+                        continue
+                    grad = param.main_grad if monitor.params_have_main_grad else param.grad
+                    if grad is None:
+                        logger.warning(f"grad is None: {name}, maybe something wrong happened.")
+                        continue
+                    tag = monitor.name2tag.get(name, {}).get(MonitorConst.PRE_GRAD)
+                    if tag is None:
+                        continue
+                    grad_dict[tag] = grad
+                    monitor.register_param_call_id("sync_grad_func", tag)
+                get_metrics(monitor.ops, grad_dict, monitor.eps, monitor.grad_context.pre)
+                out = sync_grad_func(bucket)
+                return out
+
+            return wrapper
+
+        try:
+            from megatron.core.distributed.param_and_grad_buffer import Bucket
+            self.origin_funcs.append(Bucket.start_grad_sync)
+            self.bucket_class = Bucket
+            Bucket.start_grad_sync = patch_sync(Bucket.start_grad_sync)
+            monitor.enable_megatron = True
+            logger.info("megatron version is >= core_r0.6.0 <= core_r0.8.0")
+        except ImportError:
+            monitor.enable_megatron = False
+
+        try:
+            from megatron.core.distributed.param_and_grad_buffer import _ParamAndGradBucketGroup
+            self.origin_funcs.append(_ParamAndGradBucketGroup.start_grad_sync)
+            self.bucket_class = _ParamAndGradBucketGroup
+            _ParamAndGradBucketGroup.start_grad_sync = patch_sync(_ParamAndGradBucketGroup.start_grad_sync)
+            monitor.enable_megatron = True
+            logger.info("megatron version is > core_r0.8.0 <= core_r0.9.0")
+        except ImportError:
+            monitor.enable_megatron = False | monitor.enable_megatron
+
+    def restore_grad_sync(self, monitor):
+        if not monitor.enable_megatron:
+            return
+
+        self.bucket_class.start_grad_sync = self.origin_funcs[0]
+
+
     def _get_single_state(self, torch_opt):
         state = {}
         if hasattr(torch_opt, 'param_to_cpu_states_map'):
@@ -131,7 +187,7 @@ class OptimizerMon(object):
         self.state.update(state)
 
 
-class MixPrecisionOptimizerMon(OptimizerMon):
+class MegatronMixPrecisionOptimizerMon(OptimizerMon):
     """
     混合精度优化器监控类。在混合精度训练中监控和管理优化器。
     混合精度训练通过适当降低某些计算的精度来加速训练过程并减少内存消耗。
@@ -161,7 +217,7 @@ class MegatronChainedDistributedOptimizerMon(MegatronDistributedOptimizerMon):
             super().map_fp16_to_fp32_param(opt)
 
 
-class MegatronChainedMixPrecisionOptimizerMon(MixPrecisionOptimizerMon):
+class MegatronChainedMixPrecisionOptimizerMon(MegatronMixPrecisionOptimizerMon):
     def map_fp16_to_fp32_param(self, torch_opt):
         for opt in torch_opt.chained_optimizers:
             super().map_fp16_to_fp32_param(opt)
@@ -248,6 +304,12 @@ class DeepSpeedZeroOptimizerMon(OptimizerMon):
                 grad_dict[tag] = grad
         return grad_dict
+
+    def patch_grad_sync(self, monitor):
+        pass
+
+    def restore_grad_sync(self, monitor):
+        pass
 
 
 class DeepSpeedZeroOptimizerStage0Mon(DeepSpeedZeroOptimizerMon):
@@ -291,6 +353,47 @@ class DeepSpeedZeroOptimizerStage1or2Mon(DeepSpeedZeroOptimizerMon):
                     break
 
+    def patch_grad_sync(self, monitor):
+        def patch_sync(reduce_func):
+            def wrapper(zero_optimizer, *args, **kwargs):
+                grad_dict = {}
+                for i, param, _ in zero_optimizer.params_in_ipg_bucket:
+                    if isinstance(param, int):  # for ds >= 0.17.0
+                        param = zero_optimizer.bit16_groups[i][param]
+                    name = monitor.param2name[param]
+                    tag = monitor.name2tag.get(name, {}).get(MonitorConst.PRE_GRAD)
+                    grad_dict[tag] = zero_optimizer.get_gradient_for_reduction(param)
+                    monitor.register_param_call_id("sync_grad_func", tag)
+                get_metrics(monitor.ops, grad_dict, monitor.eps, monitor.grad_context.pre)
+                out = reduce_func(zero_optimizer, *args, **kwargs)
+                return out
+
+            return wrapper
+        try:
+            from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer
+            self.origin_funcs = [
+                DeepSpeedZeroOptimizer.average_tensor,
+                DeepSpeedZeroOptimizer.buffered_reduce_fallback
+            ]
+            DeepSpeedZeroOptimizer.average_tensor = patch_sync(DeepSpeedZeroOptimizer.average_tensor)
+            DeepSpeedZeroOptimizer.buffered_reduce_fallback = \
+                patch_sync(DeepSpeedZeroOptimizer.buffered_reduce_fallback)
+            monitor.enable_deepspeed = True
+            logger.info('deepspeed enabled')
+        except Exception as e:
+            monitor.enable_deepspeed = False | monitor.enable_deepspeed
+            logger.warning(f'Seems to be using DeepSpeed ZeRO stage 1 or 2, but patching average_tensor failed: {e}')
+
+    def restore_grad_sync(self, monitor):
+        if not monitor.enable_deepspeed:
+            return
+
+        from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer
+        DeepSpeedZeroOptimizer.average_tensor = self.origin_funcs[0]
+        DeepSpeedZeroOptimizer.buffered_reduce_fallback = self.origin_funcs[1]
+
+
 class DeepSpeedZeroOptimizerStage3Mon(DeepSpeedZeroOptimizerMon):
     def __init__(self, torch_opt):
         super().__init__(torch_opt)
@@ -314,7 +417,7 @@ class DeepSpeedZeroOptimizerStage3Mon(DeepSpeedZeroOptimizerMon):
 class OptimizerMonFactory:
     _optimizer_mon_map = {
         "FP32Optimizer": OptimizerMon,
-        "Float16OptimizerWithFloat16Params": MixPrecisionOptimizerMon,
+        "Float16OptimizerWithFloat16Params": MegatronMixPrecisionOptimizerMon,
         "DistributedOptimizer": MegatronDistributedOptimizerMon,
         "SwapDistributedOptimizer": MegatronDistributedOptimizerMon,
         "ChainedDistributedOptimizer": MegatronChainedDistributedOptimizerMon,
diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_optimizer_collect.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_optimizer_collect.py
index c7cbd86bbcc8671b88aa8a7c97e24181c5a42379..e8cbd00a0f31589104a50340f990558bd0277be9 100644
--- a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_optimizer_collect.py
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_optimizer_collect.py
@@ -5,7 +5,7 @@ from unittest.mock import Mock, patch, MagicMock
 import torch
 from msprobe.core.common.const import MonitorConst
 from msprobe.pytorch.monitor.optimizer_collect import OptimizerMon, \
-    OptimizerMonFactory, MixPrecisionOptimizerMon, MegatronDistributedOptimizerMon, \
+    OptimizerMonFactory, MegatronMixPrecisionOptimizerMon, MegatronDistributedOptimizerMon, \
     MegatronChainedDistributedOptimizerMon, MegatronChainedMixPrecisionOptimizerMon, \
     DeepSpeedZeroOptimizerMon, DeepSpeedZeroOptimizerStage0Mon, \
     DeepSpeedZeroOptimizerStage1or2Mon, DeepSpeedZeroOptimizerStage3Mon
@@ -84,7 +84,7 @@ class TestMixPrecisionOptimizerMon(unittest.TestCase):
         self.mix_prec_opt = MagicMock()
         self.mix_prec_opt.float16_groups = [MagicMock()]
         self.mix_prec_opt.fp32_from_float16_groups = [MagicMock()]
-        self.optimizer = MixPrecisionOptimizerMon(self.torch_opt)
+        self.optimizer = MegatronMixPrecisionOptimizerMon(self.torch_opt)
         self.optimizer.fp16_to_fp32_param = {}
 
         # Mock fetch_mv method and set a fixed return value
@@ -372,7 +372,7 @@ class TestOptimizerMonFactory(unittest.TestCase):
         mix_optimizer_class.__name__ = "Float16OptimizerWithFloat16Params"
         mix_optimizer.__class__ = mix_optimizer_class
         self.assertIsInstance(OptimizerMonFactory.create_optimizer_mon(mix_optimizer),
-                              MixPrecisionOptimizerMon)
+                              MegatronMixPrecisionOptimizerMon)
         dis_optimizer = MagicMock()
         dis_optimizer_class = MagicMock()
         dis_optimizer_class.__name__ = "DistributedOptimizer"