From 6a5bae437b08ca421d8d7a0edcf26a4863819251 Mon Sep 17 00:00:00 2001
From: heweidong7 <511650494@qq.com>
Date: Tue, 22 Oct 2024 09:04:53 +0800
Subject: [PATCH 1/3] Fix the error raised when targets are specified
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 debug/accuracy_tools/kj600/kj600/module_hook.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/debug/accuracy_tools/kj600/kj600/module_hook.py b/debug/accuracy_tools/kj600/kj600/module_hook.py
index c26750f78..54f0ec8ad 100644
--- a/debug/accuracy_tools/kj600/kj600/module_hook.py
+++ b/debug/accuracy_tools/kj600/kj600/module_hook.py
@@ -57,8 +57,11 @@ class ModuleHookContext:
     def set_format_by_arg(self, key_name:str, target_config:dict):
         cared = target_config.get(self.module_name, self.struct)
         if key_name in cared:
-            config = cared[key_name].get('config')
-            self.format_by_arg[key_name] = config if config else cared[key_name]
+            if isinstance(cared[key_name], dict):  # cared = self.struct
+                config = cared[key_name].get('config')
+                self.format_by_arg[key_name] = config
+            else:  # cared = target_config[self.module_name]
+                self.format_by_arg[key_name] = cared[key_name]
         elif key_name in ['input', 'input_grad']:
             self.ignore_in = True

--
Gitee

From 3e41e33264d079c913f81df783a3e63fe162add5 Mon Sep 17 00:00:00 2001
From: heweidong7 <511650494@qq.com>
Date: Tue, 22 Oct 2024 09:04:53 +0800
Subject: [PATCH 2/3] Fix the error raised when targets are specified
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 debug/accuracy_tools/kj600/kj600/module_hook.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/debug/accuracy_tools/kj600/kj600/module_hook.py b/debug/accuracy_tools/kj600/kj600/module_hook.py
index 4d31c9ca0..c594c1af2 100644
--- a/debug/accuracy_tools/kj600/kj600/module_hook.py
+++ b/debug/accuracy_tools/kj600/kj600/module_hook.py
@@ -58,8 +58,11 @@ class ModuleHookContext:
     def set_format_by_arg(self, key_name:str, target_config:dict):
         cared = target_config.get(self.module_name, self.struct)
         if key_name in cared:
-            config = cared[key_name].get('config')
-            self.format_by_arg[key_name] = config if config else cared[key_name]
+            if isinstance(cared[key_name], dict):  # cared = self.struct
+                config = cared[key_name].get('config')
+                self.format_by_arg[key_name] = config
+            else:  # cared = target_config[self.module_name]
+                self.format_by_arg[key_name] = cared[key_name]
         elif key_name in ['input', 'input_grad']:
             self.ignore_in = True

--
Gitee
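Note on patches 1/3 and 2/3: before the fix, set_format_by_arg called cared[key_name].get('config')
unconditionally; when targets is specified, cared[key_name] is a plain format string, so the call
raised AttributeError. Below is a minimal standalone sketch of the two config shapes; the module
names and format values are hypothetical, and only the branch logic mirrors the patch:

    # Standalone sketch of the patched branch in ModuleHookContext.set_format_by_arg.
    def resolve_format(module_name, key_name, target_config, struct):
        cared = target_config.get(module_name, struct)
        if key_name not in cared:
            return None
        if isinstance(cared[key_name], dict):   # cared is the default struct
            return cared[key_name].get('config')
        return cared[key_name]                  # cared is target_config[module_name]

    # Hypothetical shapes: struct values are dicts carrying a 'config' entry,
    # while targets values are the format strings themselves.
    struct = {'input': {'config': 'tuple[2]:0'}}
    targets = {'layers.0.mlp': {'input': 'tuple[2]:0'}}
    print(resolve_format('layers.0.mlp', 'input', targets, struct))  # targets path
    print(resolve_format('layers.1.mlp', 'input', targets, struct))  # struct fallback
    # The old code would have evaluated targets['layers.0.mlp']['input'].get('config'),
    # raising AttributeError: 'str' object has no attribute 'get'.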
From c9df4a62863b817cf423dec9fd63533fe9e5c8ac Mon Sep 17 00:00:00 2001
From: heweidong7 <511650494@qq.com>
Date: Wed, 23 Oct 2024 10:50:18 +0800
Subject: [PATCH 3/3] Add support for mindspeed recomputation
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../accuracy_tools/kj600/kj600/module_hook.py | 18 +++++++++++++-----
 1 file changed, 13 insertions(+), 5 deletions(-)

diff --git a/debug/accuracy_tools/kj600/kj600/module_hook.py b/debug/accuracy_tools/kj600/kj600/module_hook.py
index c594c1af2..7e8c9054b 100644
--- a/debug/accuracy_tools/kj600/kj600/module_hook.py
+++ b/debug/accuracy_tools/kj600/kj600/module_hook.py
@@ -653,9 +653,12 @@ class TrainerMon:
     def _is_recomputation():
         """Check if the current operation is in the recomputation phase.
 
-        This function inspects the current call stack to determine if the 'backward' function is being
-        executed and if the execution is taking place within the 'torch/autograd/function.py' file.
-        If both conditions are met, it indicates that the current operation is in the recomputation phase.
+        This function inspects the current call stack to decide whether the current operation is in the
+        recomputation phase. A blacklist mechanism is used; the megatron and mindspeed frameworks are supported.
+        megatron: the 'backward' function is invoked from the 'torch/autograd/function.py' file.
+        mindspeed: the 'checkpoint_function_backward' function is invoked from the 'torch/autograd/function.py'
+        file, or a custom module (using CheckpointWithoutOutput) executes its 'backward' function within
+        the 'torch/_tensor.py' file.
 
         Returns:
             bool: True if in the recomputation phase, False otherwise.
@@ -663,9 +666,14 @@
         backward_function_indices = []
         call_stack = inspect.stack()
 
-        # Identify indices in the call stack where the 'backward' function is being executed
+        # Return early if a 'backward' frame is executing within the 'torch/_tensor.py' file.
+        for frame_info in call_stack:
+            if frame_info.function == 'backward' and frame_info.filename.endswith('torch/_tensor.py'):
+                return True
+
+        # Identify indices in the call stack where a blacklisted backward function is being executed
         for idx, frame_info in enumerate(call_stack):
-            if frame_info.function == 'backward':
+            if frame_info.function in ('backward', 'checkpoint_function_backward'):
                 backward_function_indices.append(idx)
 
         # Check if the execution is within 'torch/autograd/function.py' file
--
Gitee
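Note on patch 3/3: the detection relies only on inspect.stack() frame metadata. Below is a
self-contained sketch of the technique; the file-name suffixes are the ones the patch names,
but the final caller check is an assumption (the patch hunk ends before the actual check that
follows its last comment), taken here as the next outer frame's filename:

    import inspect

    def is_recomputation():
        call_stack = inspect.stack()  # index 0 is this frame, higher indices are callers
        # mindspeed CheckpointWithoutOutput path: Tensor.backward() runs inside torch/_tensor.py.
        for frame_info in call_stack:
            if frame_info.function == 'backward' and frame_info.filename.endswith('torch/_tensor.py'):
                return True
        # megatron / mindspeed checkpoint path: a blacklisted backward frame whose caller
        # (the next outer frame) lives in torch/autograd/function.py.
        for idx, frame_info in enumerate(call_stack):
            if frame_info.function in ('backward', 'checkpoint_function_backward'):
                if idx + 1 < len(call_stack) and call_stack[idx + 1].filename.endswith('torch/autograd/function.py'):
                    return True
        return False

    print(is_recomputation())  # False in a plain script: no blacklisted frames on the stack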