diff --git a/debug/accuracy_tools/msprobe/core/common_config.py b/debug/accuracy_tools/msprobe/core/common_config.py index 1479a018e5281adf2ddcd8a916fbc299acc6f784..da95b138c157063a521ec3d83c2a64208dde67cf 100644 --- a/debug/accuracy_tools/msprobe/core/common_config.py +++ b/debug/accuracy_tools/msprobe/core/common_config.py @@ -36,7 +36,8 @@ class CommonConfig: f'For the hyphen(-) in step, the left boundary ({borderline[0]}) cannot be greater than the right boundary ({borderline[1]}).') return continual_step - def get_real_step(self, step_input): + @classmethod + def get_real_step(cls, step_input): if step_input is None: return [] if not isinstance(step_input, list): @@ -49,7 +50,7 @@ class CommonConfig: if isinstance(step, int) and step >= 0: real_step.append(step) elif isinstance(step, str) and '-' in step: - continual_step = self.get_step_from_string(step) + continual_step = cls.get_step_from_string(step) real_step.extend(continual_step) real_step = list(set(real_step)) real_step.sort() diff --git a/debug/accuracy_tools/msprobe/pytorch/debugger/precision_debugger.py b/debug/accuracy_tools/msprobe/pytorch/debugger/precision_debugger.py index 5c11f493125b9cc85cff9523e6da3903f89cdec5..54c79d6a66008594f8c23417988fa3c4013e8956 100644 --- a/debug/accuracy_tools/msprobe/pytorch/debugger/precision_debugger.py +++ b/debug/accuracy_tools/msprobe/pytorch/debugger/precision_debugger.py @@ -4,6 +4,7 @@ import torch from msprobe.core.common.const import Const, FileCheckConst from msprobe.core.common.exceptions import MsprobeException from msprobe.core.common.file_utils import FileChecker +from msprobe.core.common_config import CommonConfig from msprobe.pytorch.common.log import logger from msprobe.pytorch.debugger.debugger_config import DebuggerConfig from msprobe.pytorch.grad_probe.grad_monitor import GradientMonitor @@ -52,7 +53,7 @@ class PrecisionDebugger: self.gm = GradientMonitor(common_config, task_config) return if step: - common_config.step = step + common_config.step = CommonConfig.get_real_step(step) self.config = DebuggerConfig( common_config, task_config, task, dump_path, level )