From 356fc6fc19896baf2d864bc7197c7d3ead70fb73 Mon Sep 17 00:00:00 2001
From: qianggee
Date: Mon, 23 Dec 2024 02:26:29 +0000
Subject: [PATCH] resolve merge conflict

---
 .../msprobe/pytorch/monitor/module_hook.py | 27 ++++++++++++++++++++-------
 1 file changed, 20 insertions(+), 7 deletions(-)

diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py b/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py
index 23a9dec57..9081329a2 100644
--- a/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py
+++ b/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py
@@ -817,11 +817,14 @@ class TrainerMon:
         return hooked_count
 
     def _patch_grad_sync(self):
+        if not self.wg_distribution:
+            return
+
         def patch_sync(sync_grad_func):
             def wrapper(bucket):
                 grad_dict = {}
                 for param, name in self.param2name.items():
-                    if param not in bucket.params_list:
+                    if not param_in_bucket(param, bucket):
                         continue
                     grad = param.main_grad if self.params_have_main_grad else param.grad
                     if grad is None:
@@ -837,19 +840,28 @@ class TrainerMon:
                 return out
 
             return wrapper
-
+
         try:
             from megatron.core.distributed.param_and_grad_buffer import Bucket
+            Bucket.start_grad_sync = patch_sync(Bucket.start_grad_sync)
             self.enable_megatron = True
+            logger.info("megatron version is >= core_r0.6.0 and <= core_r0.8.0")
+            def param_in_bucket(param, bucket):
+                return param in bucket.params
         except ImportError:
             self.enable_megatron = False
 
-        if not self.wg_distribution:
-            return
+        try:
+            from megatron.core.distributed.param_and_grad_buffer import _ParamAndGradBucketGroup
+            _ParamAndGradBucketGroup.start_grad_sync = patch_sync(_ParamAndGradBucketGroup.start_grad_sync)
+            self.enable_megatron = True
+            logger.info("megatron version is > core_r0.8.0")
+            def param_in_bucket(param, bucket_group):
+                return param in bucket_group.param_to_bucket
+        except ImportError:
+            pass  # keep the result of the Bucket probe above; only an older megatron is present
 
-        if self.enable_megatron:
-            Bucket.start_grad_sync = patch_sync(Bucket.start_grad_sync)  # differ in different megatron version
-        else:
+        if not self.enable_megatron:
             self._hook_weights()
 
     def _hook_weights(self):
@@ -866,6 +878,7 @@ class TrainerMon:
             else:
                 context_dict[key] = param.grad.clone()
 
+        logger.info('hooking weights')
         for param, name in self.param2name.items():
             key = get_summary_writer_tag_name(name, 'acc_grad', self.rank)
             setattr(param, 'micro_step', 0)
-- 
Gitee
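A note on the pattern this patch relies on, for reviewers who have not seen it before. `_patch_grad_sync` monkey-patches a gradient-sync method at the class level, so every bucket instance flows through a wrapper that inspects gradients before the original communication runs. The try/except ImportError blocks act as a version probe: `Bucket` is importable on Megatron core_r0.6.0 through core_r0.8.0, while `_ParamAndGradBucketGroup` replaces it on later versions, so whichever import succeeds selects both the class to patch and the matching `param_in_bucket` membership test. Because the two probes share one `enable_megatron` flag, the second `except` deliberately leaves the flag untouched; resetting it there would discard a successful `Bucket` probe and force the `_hook_weights` fallback even though patching succeeded. The sketch below is a minimal, self-contained illustration of the class-level patch only; `DummyBucket` and `inspect_grads` are stand-ins invented for this sketch and exist in neither msprobe nor Megatron.

    # Minimal sketch of the class-level monkey-patch; DummyBucket and
    # inspect_grads are invented stand-ins, not msprobe or Megatron names.
    def patch_sync(sync_grad_func):
        def wrapper(bucket):
            inspect_grads(bucket)          # observation point, runs pre-communication
            return sync_grad_func(bucket)  # the original synchronization still runs
        return wrapper

    def inspect_grads(bucket):
        print(f"observing {len(bucket.params)} params before grad sync")

    class DummyBucket:
        def __init__(self, params):
            self.params = params

        def start_grad_sync(self):
            print("reduce-scattering gradients ...")

    # Patch once at the class level, as the diff does for Bucket.start_grad_sync;
    # every later instance call then goes through the wrapper.
    DummyBucket.start_grad_sync = patch_sync(DummyBucket.start_grad_sync)
    DummyBucket(params=["w1", "w2", "w3"]).start_grad_sync()
    # -> observing 3 params before grad sync
    # -> reduce-scattering gradients ...

Patching the class rather than each instance is what lets the monitor attach before Megatron builds its buckets; the wrapper closes over `param_in_bucket`, which is bound later in `_patch_grad_sync` but only dereferenced when the sync actually runs.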