From ec611c3af4c58351cacbacf4eee917ed08e91b16 Mon Sep 17 00:00:00 2001 From: makai Date: Fri, 20 Sep 2024 11:11:30 +0800 Subject: [PATCH 1/3] check step --- debug/accuracy_tools/msprobe/core/common_config.py | 5 +++-- .../msprobe/pytorch/debugger/precision_debugger.py | 5 +++++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/debug/accuracy_tools/msprobe/core/common_config.py b/debug/accuracy_tools/msprobe/core/common_config.py index 1479a018e..da95b138c 100644 --- a/debug/accuracy_tools/msprobe/core/common_config.py +++ b/debug/accuracy_tools/msprobe/core/common_config.py @@ -36,7 +36,8 @@ class CommonConfig: f'For the hyphen(-) in step, the left boundary ({borderline[0]}) cannot be greater than the right boundary ({borderline[1]}).') return continual_step - def get_real_step(self, step_input): + @classmethod + def get_real_step(cls, step_input): if step_input is None: return [] if not isinstance(step_input, list): @@ -49,7 +50,7 @@ class CommonConfig: if isinstance(step, int) and step >= 0: real_step.append(step) elif isinstance(step, str) and '-' in step: - continual_step = self.get_step_from_string(step) + continual_step = cls.get_step_from_string(step) real_step.extend(continual_step) real_step = list(set(real_step)) real_step.sort() diff --git a/debug/accuracy_tools/msprobe/pytorch/debugger/precision_debugger.py b/debug/accuracy_tools/msprobe/pytorch/debugger/precision_debugger.py index 5c11f4931..82095bce9 100644 --- a/debug/accuracy_tools/msprobe/pytorch/debugger/precision_debugger.py +++ b/debug/accuracy_tools/msprobe/pytorch/debugger/precision_debugger.py @@ -4,6 +4,7 @@ import torch from msprobe.core.common.const import Const, FileCheckConst from msprobe.core.common.exceptions import MsprobeException from msprobe.core.common.file_utils import FileChecker +from msprobe.core.common_config import CommonConfig from msprobe.pytorch.common.log import logger from msprobe.pytorch.debugger.debugger_config import DebuggerConfig from msprobe.pytorch.grad_probe.grad_monitor import GradientMonitor @@ -99,6 +100,10 @@ class PrecisionDebugger: MsprobeException.INVALID_PARAM_ERROR, f"model must be a torch.nn.Module" ) + @staticmethod + def get_real_step(step): + CommonConfig.get_real_step(step) + @classmethod def start(cls): instance = cls._instance -- Gitee From 80ba4a3f4941dbf3245474942add81cacdd6ec71 Mon Sep 17 00:00:00 2001 From: makai Date: Fri, 20 Sep 2024 11:14:14 +0800 Subject: [PATCH 2/3] check step --- .../msprobe/pytorch/debugger/precision_debugger.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/debug/accuracy_tools/msprobe/pytorch/debugger/precision_debugger.py b/debug/accuracy_tools/msprobe/pytorch/debugger/precision_debugger.py index 82095bce9..f8a8725d6 100644 --- a/debug/accuracy_tools/msprobe/pytorch/debugger/precision_debugger.py +++ b/debug/accuracy_tools/msprobe/pytorch/debugger/precision_debugger.py @@ -53,7 +53,7 @@ class PrecisionDebugger: self.gm = GradientMonitor(common_config, task_config) return if step: - common_config.step = step + common_config.step = self.get_real_step(step) self.config = DebuggerConfig( common_config, task_config, task, dump_path, level ) -- Gitee From 72f0a1a36ea9b2f4da0bc981de1291fe2f783140 Mon Sep 17 00:00:00 2001 From: makai Date: Fri, 20 Sep 2024 11:21:29 +0800 Subject: [PATCH 3/3] check step --- .../msprobe/pytorch/debugger/precision_debugger.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/debug/accuracy_tools/msprobe/pytorch/debugger/precision_debugger.py b/debug/accuracy_tools/msprobe/pytorch/debugger/precision_debugger.py index f8a8725d6..54c79d6a6 100644 --- a/debug/accuracy_tools/msprobe/pytorch/debugger/precision_debugger.py +++ b/debug/accuracy_tools/msprobe/pytorch/debugger/precision_debugger.py @@ -53,7 +53,7 @@ class PrecisionDebugger: self.gm = GradientMonitor(common_config, task_config) return if step: - common_config.step = self.get_real_step(step) + common_config.step = CommonConfig.get_real_step(step) self.config = DebuggerConfig( common_config, task_config, task, dump_path, level ) @@ -100,10 +100,6 @@ class PrecisionDebugger: MsprobeException.INVALID_PARAM_ERROR, f"model must be a torch.nn.Module" ) - @staticmethod - def get_real_step(step): - CommonConfig.get_real_step(step) - @classmethod def start(cls): instance = cls._instance -- Gitee