From 1a9492e6f519ed172eeec87441993d71e079bd8c Mon Sep 17 00:00:00 2001
From: qiangge
Date: Wed, 2 Jul 2025 09:30:09 +0800
Subject: [PATCH] patch ZeRO stage 1/2 average_tensor to accelerate gradient monitoring

Collect pre-reduction gradients by wrapping
DeepSpeedZeroOptimizer.average_tensor (ZeRO stage 1/2) instead of falling
back to the default per-weight hooks, and rename MixPrecisionOptimizerMon
to MegatronMixPrecisionOptimizerMon to make its Megatron scope explicit.
---
 .../msprobe/pytorch/monitor/module_hook.py    | 65 ++++++++++++++-----
 .../pytorch/monitor/optimizer_collect.py      |  6 +-
 2 files changed, 51 insertions(+), 20 deletions(-)

diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py b/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py
index e34fb21d5ee..9fedd8e111b 100644
--- a/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py
+++ b/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py
@@ -193,6 +193,7 @@ class TrainerMon:
         self.dp_group = None
         self.tp_group = None
         self.enable_megatron = False
+        self.enable_deepspeed = False
         self.fsdp_wrapped_module = False
         self.micro_batch_number = 1
         self.optimizer_mon = None
@@ -1058,6 +1059,21 @@ class TrainerMon:
         return hooked_count
 
     def _patch_grad_sync(self):
+        def patch_ds_wg_reduce(reduce_func):
+            def wrapper(zero_optimizer, *args, **kwargs):
+                grad_dict = {}
+                for i, param, param_id in zero_optimizer.params_in_ipg_bucket:
+                    if isinstance(param, int):  # for ds >= 0.17.0
+                        param = zero_optimizer.bit16_groups[i][param]  # for ds >= 0.17.0
+                    name = self.param2name[param]
+                    tag = self.name2tag.get(name, {}).get(MonitorConst.PRE_GRAD)
+                    grad_dict[tag] = zero_optimizer.get_gradient_for_reduction(param)
+                get_metrics(self.ops, grad_dict, self.eps, self.grad_context.pre)
+                out = reduce_func(zero_optimizer, *args, **kwargs)
+                return out
+
+            return wrapper
+
         def patch_sync(sync_grad_func):
             def wrapper(bucket):
                 grad_dict = {}
@@ -1093,24 +1109,39 @@ class TrainerMon:
         if self.monitor_mbs_grad:
             self._hook_weights()
             return
-        try:
-            from megatron.core.distributed.param_and_grad_buffer import Bucket
-            self.origin_start_grad_sync = Bucket.start_grad_sync
-            Bucket.start_grad_sync = patch_sync(Bucket.start_grad_sync)
-            self.enable_megatron = True
-            logger.info("megatron version is >= core_r0.6.0 <= core_r0.8.0")
-        except ImportError:
-            self.enable_megatron = False
+
+        if 'Megatron' in self.optimizer_mon.torch_opt.__class__.__name__:
+            try:
+                from megatron.core.distributed.param_and_grad_buffer import Bucket
+                self.origin_start_grad_sync = Bucket.start_grad_sync
+                Bucket.start_grad_sync = patch_sync(Bucket.start_grad_sync)
+                self.enable_megatron = True
+                logger.info("megatron version is >= core_r0.6.0 <= core_r0.8.0")
+            except ImportError:
+                self.enable_megatron = False
 
-        try:
-            from megatron.core.distributed.param_and_grad_buffer import _ParamAndGradBucketGroup
-            self.origin_start_grad_sync = _ParamAndGradBucketGroup.start_grad_sync
-            _ParamAndGradBucketGroup.start_grad_sync = patch_sync(_ParamAndGradBucketGroup.start_grad_sync)
-            self.enable_megatron = True
-            logger.info("megatron version is > core_r0.8.0 <= core_r0.9.0")
-        except ImportError:
-            self.enable_megatron = False | self.enable_megatron
-        if self.enable_megatron:
+            try:
+                from megatron.core.distributed.param_and_grad_buffer import _ParamAndGradBucketGroup
+                self.origin_start_grad_sync = _ParamAndGradBucketGroup.start_grad_sync
+                _ParamAndGradBucketGroup.start_grad_sync = patch_sync(_ParamAndGradBucketGroup.start_grad_sync)
+                self.enable_megatron = True
+                logger.info("megatron version is > core_r0.8.0 <= core_r0.9.0")
+            except ImportError:
+                self.enable_megatron = False | self.enable_megatron
+
+        if self.optimizer_mon.torch_opt.__class__.__name__ == 'DeepSpeedZeroOptimizer':  # stage 1 or 2
+            try:
+                from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer
+                self.origin_start_grad_sync = DeepSpeedZeroOptimizer.average_tensor
+                DeepSpeedZeroOptimizer.average_tensor = patch_ds_wg_reduce(DeepSpeedZeroOptimizer.average_tensor)
+                DeepSpeedZeroOptimizer.buffered_reduce_fallback = patch_ds_wg_reduce(DeepSpeedZeroOptimizer.buffered_reduce_fallback)
+                self.enable_deepspeed = True
+                logger.info('DeepSpeed ZeRO stage 1/2 detected, average_tensor patched')
+            except Exception as e:
+                self.enable_deepspeed = False
+                logger.warning(f'DeepSpeed ZeRO stage 1 or 2 detected, but patching average_tensor failed: {e}')
+
+        if self.enable_megatron or self.enable_deepspeed:
             return
 
         # default hook weights
diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/optimizer_collect.py b/debug/accuracy_tools/msprobe/pytorch/monitor/optimizer_collect.py
index 3f8140cb7dc..593e8462679 100644
--- a/debug/accuracy_tools/msprobe/pytorch/monitor/optimizer_collect.py
+++ b/debug/accuracy_tools/msprobe/pytorch/monitor/optimizer_collect.py
@@ -128,7 +128,7 @@ class OptimizerMon(object):
         self.state.update(state)
 
 
-class MixPrecisionOptimizerMon(OptimizerMon):
+class MegatronMixPrecisionOptimizerMon(OptimizerMon):
     """
     Mixed-precision optimizer monitoring class. Monitors and manages the optimizer in mixed-precision training.
    Mixed-precision training accelerates training and reduces memory consumption by appropriately lowering the precision of certain computations.
@@ -158,7 +158,7 @@ class MegatronChainedDistributedOptimizerMon(MegatronDistributedOptimizerMon):
             super().map_fp16_to_fp32_param(opt)
 
 
-class MegatronChainedMixPrecisionOptimizerMon(MixPrecisionOptimizerMon):
+class MegatronChainedMixPrecisionOptimizerMon(MegatronMixPrecisionOptimizerMon):
     def map_fp16_to_fp32_param(self, torch_opt):
         for opt in torch_opt.chained_optimizers:
             super().map_fp16_to_fp32_param(opt)
@@ -311,7 +311,7 @@ class DeepSpeedZeroOptimizerStage3Mon(DeepSpeedZeroOptimizerMon):
 class OptimizerMonFactory:
     _optimizer_mon_map = {
         "FP32Optimizer": OptimizerMon,
-        "Float16OptimizerWithFloat16Params": MixPrecisionOptimizerMon,
+        "Float16OptimizerWithFloat16Params": MegatronMixPrecisionOptimizerMon,
        "DistributedOptimizer": MegatronDistributedOptimizerMon,
         "SwapDistributedOptimizer": MegatronDistributedOptimizerMon,
         "ChainedDistributedOptimizer": MegatronChainedDistributedOptimizerMon,
--
Gitee
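
Illustration only, not part of the patch: the module_hook.py hunk above works by wrapping an optimizer method so the monitor can observe every gradient in the current reduce bucket before delegating to the original reduction. The sketch below shows that wrap-then-delegate pattern in isolation; FakeZeroOptimizer, its params_in_ipg_bucket contents, and collect_metrics are hypothetical stand-ins, not the real DeepSpeed or msprobe APIs.

# Minimal sketch of the pattern used by patch_ds_wg_reduce.
class FakeZeroOptimizer:
    def __init__(self):
        # (group_index, param, param_id) tuples, mirroring params_in_ipg_bucket
        self.params_in_ipg_bucket = [(0, "weight_a", 0), (0, "weight_b", 1)]

    def average_tensor(self, flat_buffer):
        # stand-in for the real bucket reduction
        print(f"reducing {flat_buffer}")


def collect_metrics(grads):
    # stand-in for get_metrics(): record statistics before reduction
    print(f"monitored {len(grads)} grads before reduce")


def patch_reduce(reduce_func):
    def wrapper(zero_optimizer, *args, **kwargs):
        grads = [param for _, param, _ in zero_optimizer.params_in_ipg_bucket]
        collect_metrics(grads)                                # observe first ...
        return reduce_func(zero_optimizer, *args, **kwargs)   # ... then delegate
    return wrapper


FakeZeroOptimizer.average_tensor = patch_reduce(FakeZeroOptimizer.average_tensor)
FakeZeroOptimizer().average_tensor("flat_grad_buffer")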