diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py b/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py
index 16d96e9f2fa961ccf731e2acc8afb12af8c831a8..53f970e0d9201f228cd479ce701add3a51681b5e 100644
--- a/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py
+++ b/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py
@@ -232,6 +232,7 @@ class TrainerMon:
         self.dp_group = None
         self.tp_group = None
         self.enable_megatron = False
+        self.enable_deepspeed = False
         self.fsdp_wrapped_module = False
         self.micro_batch_number = 1
         self.optimizer_mon = None
@@ -1320,6 +1321,21 @@ class TrainerMon:
         return hooked_count
 
     def _patch_grad_sync(self):
+        def patch_average_tensor(average_tensor):
+            # Capture each param's gradient before DeepSpeed flattens and
+            # reduce-scatters the bucket, then call the original method.
+            def wrapper(zero_optimizer, tensor):
+                grad_dict = {}
+                for i, param, param_id in zero_optimizer.params_in_ipg_bucket:
+                    name = self.param2name[param]
+                    tag = self.name2tag.get(name, {}).get(MonitorConst.PRE_GRAD)
+                    grad_dict[tag] = zero_optimizer.get_gradient_for_reduction(param)
+                get_metrics(self.ops, grad_dict, self.eps, self.grad_context.pre)
+                out = average_tensor(zero_optimizer, tensor)
+                return out
+
+            return wrapper
+
         def patch_sync(sync_grad_func):
             def wrapper(bucket):
                 grad_dict = {}
@@ -1372,7 +1388,21 @@ class TrainerMon:
                 logger.info("megatron version is > core_r0.8.0 <= core_r0.9.0")
             except ImportError:
                 self.enable_megatron = False | self.enable_megatron
-        if self.enable_megatron:
+
+        if self.optimizer_mon.torch_opt.__class__.__name__ == 'DeepSpeedZeroOptimizer':
+            # DeepSpeed ZeRO stage 1/2: wrap average_tensor so per-param
+            # gradients are visible before reduction.
+            try:
+                from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer
+                self.origin_start_grad_sync = DeepSpeedZeroOptimizer.average_tensor
+                DeepSpeedZeroOptimizer.average_tensor = patch_average_tensor(DeepSpeedZeroOptimizer.average_tensor)
+                self.enable_deepspeed = True | self.enable_deepspeed
+                logger.info('DeepSpeed ZeRO stage 1/2 detected, gradient monitoring enabled')
+            except Exception as e:
+                self.enable_deepspeed = False | self.enable_deepspeed
+                logger.warning(f'DeepSpeed ZeRO stage 1/2 detected, but patching average_tensor failed: {e}')
+
+        if self.enable_megatron or self.enable_deepspeed:
             return
 
         # default hook weights
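
Note: the patch works by replacing the unbound DeepSpeedZeroOptimizer.average_tensor at the class level with a wrapper that records each parameter's gradient (via get_gradient_for_reduction) while the gradients are still addressable per parameter, before ZeRO stage 1/2 flattens and reduce-scatters the bucket, and then delegates to the original method. Below is a minimal, self-contained sketch of that wrap-and-restore pattern; ToyZeroOptimizer and collect_grad_norms are hypothetical stand-ins for the DeepSpeed optimizer and msprobe's get_metrics, not the real APIs.

    # A minimal sketch of the wrap-and-restore pattern the patch uses.
    # ToyZeroOptimizer and collect_grad_norms are hypothetical stand-ins,
    # not the real DeepSpeed or msprobe APIs.
    import torch


    class ToyZeroOptimizer:
        """Stand-in for deepspeed.runtime.zero.stage_1_and_2.DeepSpeedZeroOptimizer."""

        def __init__(self, params):
            # Mirrors the (i, param, param_id) tuples the patch iterates over.
            self.params_in_ipg_bucket = [(0, p, id(p)) for p in params]

        def get_gradient_for_reduction(self, param):
            return param.grad

        def average_tensor(self, tensor):
            # The real method reduce-scatters the flat gradient bucket; the
            # toy just scales it so the call is observable.
            tensor.div_(2.0)


    def collect_grad_norms(grad_dict):
        # Hypothetical metric hook standing in for msprobe's get_metrics().
        for tag, grad in grad_dict.items():
            print(tag, grad.norm().item())


    def patch_average_tensor(average_tensor):
        # Same shape as the patch above: snapshot per-param gradients while
        # they are still individually addressable, then delegate.
        def wrapper(zero_optimizer, tensor):
            grad_dict = {}
            for _, param, param_id in zero_optimizer.params_in_ipg_bucket:
                grad_dict[f'param_{param_id}'] = zero_optimizer.get_gradient_for_reduction(param)
            collect_grad_norms(grad_dict)
            return average_tensor(zero_optimizer, tensor)

        return wrapper


    # Patch at the class level, exercise it, then restore the original so
    # the hook can be undone, as the patch does via origin_start_grad_sync.
    origin_average_tensor = ToyZeroOptimizer.average_tensor
    ToyZeroOptimizer.average_tensor = patch_average_tensor(ToyZeroOptimizer.average_tensor)

    p = torch.nn.Parameter(torch.ones(4))
    p.grad = torch.full((4,), 2.0)
    opt = ToyZeroOptimizer([p])
    opt.average_tensor(p.grad.clone())

    ToyZeroOptimizer.average_tensor = origin_average_tensor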