diff --git a/debug/accuracy_tools/msprobe/core/common_config.py b/debug/accuracy_tools/msprobe/core/common_config.py
index 688734be8acbfae23ff453fea67787a577effa13..6b8ebdf116e7121f0b491e78c8d5d7c34efd5d6a 100644
--- a/debug/accuracy_tools/msprobe/core/common_config.py
+++ b/debug/accuracy_tools/msprobe/core/common_config.py
@@ -14,6 +14,7 @@ class CommonConfig:
         self.acl_config = json_config.get('acl_config')
         self.is_deterministic = json_config.get('is_deterministic', False)
         self.enable_dataloader = json_config.get('enable_dataloader', False)
+        self.enable_optimizer = json_config.get('enable_optimizer', False)
         self._check_config()
 
     def _check_config(self):
diff --git a/debug/accuracy_tools/msprobe/pytorch/debugger/debugger_config.py b/debug/accuracy_tools/msprobe/pytorch/debugger/debugger_config.py
index 7c32be7cc34e1f90334655a1ac03fbd1980b2228..5d94656e4d6ccb2dd645ea0459c57f102a8994be 100644
--- a/debug/accuracy_tools/msprobe/pytorch/debugger/debugger_config.py
+++ b/debug/accuracy_tools/msprobe/pytorch/debugger/debugger_config.py
@@ -13,6 +13,7 @@ class DebuggerConfig:
         self.seed = common_config.seed if common_config.seed else 1234
         self.is_deterministic = common_config.is_deterministic
         self.enable_dataloader = common_config.enable_dataloader
+        self.enable_optimizer = common_config.enable_optimizer
         self.scope = task_config.scope if task_config.scope else []
         self.list = task_config.list if task_config.list else []
         self.data_mode = task_config.data_mode if task_config.data_mode else ["all"]
diff --git a/debug/accuracy_tools/msprobe/pytorch/debugger/precision_debugger.py b/debug/accuracy_tools/msprobe/pytorch/debugger/precision_debugger.py
index 8433f0af695bf7afb08f32d99fc6f58788ab8b97..2e3889e5299f5394a4a83e10bcdd230113346b6b 100644
--- a/debug/accuracy_tools/msprobe/pytorch/debugger/precision_debugger.py
+++ b/debug/accuracy_tools/msprobe/pytorch/debugger/precision_debugger.py
@@ -28,6 +28,7 @@ class PrecisionDebugger:
         level=None,
         model=None,
         step=None,
+        enable_optimizer=None,
     ):
         if not hasattr(self, "initialized"):
             self.api_origin = False
@@ -40,6 +41,8 @@
                 return
             if step:
                 common_config.step = step
+            if enable_optimizer is not None:
+                common_config.enable_optimizer = enable_optimizer
             self.config = DebuggerConfig(
                 common_config, task_config, task, dump_path, level
             )
@@ -49,6 +52,9 @@
             if self.enable_dataloader:
                 logger.warning_on_rank_0("The enable_dataloader feature will be deprecated in the future.")
                 dataloader._BaseDataLoaderIter.__next__ = iter_tracer(dataloader._BaseDataLoaderIter.__next__)
+            self.enable_optimizer = self.config.enable_optimizer
+            if self.enable_optimizer:
+                logger.info_on_rank_0("enable_optimizer is on; manual start()/stop()/step() calls will be skipped.")
 
     @property
     def instance(self):
@@ -102,6 +108,13 @@
             raise Exception("PrecisionDebugger instance is not created.")
         cls._instance.service.step()
 
+    @classmethod
+    def hook_optimizer(cls):
+        instance = cls._instance
+        if not instance:
+            raise Exception("PrecisionDebugger instance is not created.")
+        instance.service.hook_optimizer(instance.model)
+
     @classmethod
     def monitor(cls, model):
         if not cls._instance:
diff --git a/debug/accuracy_tools/msprobe/pytorch/service.py b/debug/accuracy_tools/msprobe/pytorch/service.py
index 980c7d840cae208aae5c137d6a93e89279b580ca..1221ab7327a269ff4ab5042c86d6d5a39964b963 100644
--- a/debug/accuracy_tools/msprobe/pytorch/service.py
+++ b/debug/accuracy_tools/msprobe/pytorch/service.py
@@ -5,6 +5,7 @@
 from pathlib import Path
 from collections import namedtuple
 import torch
+from torch.optim.optimizer import register_optimizer_step_pre_hook, register_optimizer_step_post_hook
 from msprobe.core.common.const import Const, FileCheckConst
 from msprobe.core.common.exceptions import DistributedNotInitializedError, MsprobeException
 from msprobe.core.common.file_check import FileChecker, check_path_before_create
@@ -110,6 +111,18 @@ class Service:
         forward_hook_torch_version_below_2_fn = functools.partial(forward_hook_torch_version_below_2, forward_name_template)
         return HookFn(pre_forward_hook_fn, forward_hook_fn, backward_hook_fn, forward_hook_torch_version_below_2_fn)
 
+    def hook_optimizer(self, model):
+        def optimizer_pre_step_hook(optimizer, args, kwargs):
+            self.stop()
+            self.step()
+
+        def optimizer_post_step_hook(optimizer, args, kwargs):
+            self.start(model)
+
+        register_optimizer_step_pre_hook(optimizer_pre_step_hook)
+        register_optimizer_step_post_hook(optimizer_post_step_hook)
+        self.start(model)
+
     def step(self):
         self.current_iter += 1
         self.data_collector.update_iter(self.current_iter)
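A minimal usage sketch for reviewers, assuming the documented entry point `from msprobe.pytorch import PrecisionDebugger` and a default config resolved when `config_path` is omitted; the model, optimizer, and training loop below are placeholder assumptions, and only `enable_optimizer` and `hook_optimizer()` come from this patch:

```python
import torch
from msprobe.pytorch import PrecisionDebugger

# Placeholder model/optimizer; any torch.nn.Module / torch.optim.Optimizer pair works.
model = torch.nn.Linear(8, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# enable_optimizer=True is the switch added by this patch; task/dump_path are
# pre-existing msprobe options and would normally come from config.json.
debugger = PrecisionDebugger(dump_path="./dump", task="statistics", model=model,
                             enable_optimizer=True)
debugger.hook_optimizer()  # registers the global step hooks and calls start(model)

for _ in range(3):
    loss = model(torch.randn(4, 8)).sum()
    loss.backward()
    optimizer.step()       # pre-hook: stop() + step(); post-hook: start(model)
    optimizer.zero_grad()
```

Design note: because the hooks wrap `optimizer.step()`, each dump iteration spans everything between two consecutive optimizer steps, and the optimizer update itself is excluded from the dump. Worth flagging as well: `register_optimizer_step_pre_hook`/`register_optimizer_step_post_hook` were introduced in PyTorch 2.0, so the unconditional import in `service.py` will fail on the torch < 2 versions that the surrounding code (`forward_hook_torch_version_below_2`) still accommodates.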