diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_collector.py b/debug/accuracy_tools/msprobe/core/data_dump/data_collector.py
index d21af4b82dc36c0d5dd42153808a004f9b9514aa..559c1df73c27cf90d10be4e9712fc93a8af95f7e 100644
--- a/debug/accuracy_tools/msprobe/core/data_dump/data_collector.py
+++ b/debug/accuracy_tools/msprobe/core/data_dump/data_collector.py
@@ -39,6 +39,7 @@ class DataCollector:
         self.module_count = {}
         self.scope = ScopeFactory(self.config).build_scope()
         self.backward_module_names = {}
+        self.optimizer_status = ""
         atexit.register(self.write_json)
 
     @property
@@ -144,7 +145,12 @@ class DataCollector:
 
     def update_construct(self, name):
         if self.config.level not in DataCollector.level_without_construct:
-            self.data_writer.update_construct({name: self.module_processor.api_parent_node})
+            if self.optimizer_status == "clip_grad":
+                self.data_writer.update_construct({name: "clip_grad"})
+            elif self.optimizer_status == "optimizer":
+                self.data_writer.update_construct({name: "optimizer"})
+            else:
+                self.data_writer.update_construct({name: self.module_processor.api_parent_node})
             self.data_writer.update_construct(self.module_processor.module_node)
 
     def handle_data(self, name, data_info, flush=False):
diff --git a/debug/accuracy_tools/msprobe/pytorch/service.py b/debug/accuracy_tools/msprobe/pytorch/service.py
index f2538370118efaea3d4a00288c8a6ba953dbb2f2..ac8fa1cd5e60c8a6cc9d4d9095be2ec51864057c 100644
--- a/debug/accuracy_tools/msprobe/pytorch/service.py
+++ b/debug/accuracy_tools/msprobe/pytorch/service.py
@@ -32,6 +32,8 @@ from msprobe.pytorch.dump.kernel_dump.kernel_config import create_kernel_config_
 from msprobe.pytorch.dump.module_dump.module_processer import ModuleProcesser
 from msprobe.pytorch.hook_module.api_registry import api_register
 from msprobe.pytorch.hook_module.hook_module import HOOKModule
+from msprobe.pytorch.module_processer import ModuleProcesser
+from torch.optim.optimizer import register_optimizer_step_pre_hook, register_optimizer_step_post_hook
 
 torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
 if torch_version_above_or_equal_2:
@@ -281,6 +283,29 @@ class Service:
             )
             api_register.api_modularity()
 
+        if self.config.level == "mix":
+            def optimizer_pre_step_hook(optimizer, args, kwargs):
+                self.data_collector.optimizer_status = "optimizer"
+
+            def optimizer_post_step_hook(optimizer, args, kwargs):
+                self.data_collector.optimizer_status = "end_optimizer"
+
+            def patch_clip_grad(func):
+                def wrapper(*args, **kwargs):
+                    self.data_collector.optimizer_status = "clip_grad"
+                    func(*args, **kwargs)
+                    self.data_collector.optimizer_status = "end_clip_grad"
+                return wrapper
+
+            register_optimizer_step_pre_hook(optimizer_pre_step_hook)
+            register_optimizer_step_post_hook(optimizer_post_step_hook)
+            try:
+                from megatron.core.optimizer import MegatronOptimizer
+                MegatronOptimizer.clip_grad_norm = patch_clip_grad(MegatronOptimizer.clip_grad_norm)
+            except ImportError:
+                logger.warning_on_rank_0("Fail to patch megatron clip grad function.")
+
+
     def attl_init(self):
         if self.config.online_run_ut:
             from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.attl import ATTLConfig, ATTL
@@ -311,7 +336,7 @@ class Service:
         elif self.attl.socket_manager is not None:
             logger.info(f"pid: {os.getpid()} finished, start send STOP signal.")
             self.attl.socket_manager.send_stop_signal()
-
+
     def reset_status(self):
         ModuleProcesser.reset_module_stats()
         HOOKModule.reset_module_stats()
@@ -319,7 +344,7 @@ class Service:
         if self.config.level == Const.LEVEL_L2:
             self.data_collector.data_processor.reset_status()
-            return
+            return
         if self.config.step and self.current_iter not in self.config.step:
             return
         if self.config.rank and self.current_rank not in self.config.rank:
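For context (not part of the patch): the change relies on PyTorch's global optimizer step hooks, available since 2.0, which fire around every `Optimizer.step()` call regardless of which optimizer instance the training script uses. A minimal, self-contained sketch of that mechanism is below; the `status` dict is a stand-in for the collector's `optimizer_status` field, and the hook signatures match `torch.optim.optimizer.register_optimizer_step_pre_hook` / `register_optimizer_step_post_hook`.

```python
# Sketch only, assuming PyTorch >= 2.0: global optimizer step hooks let a dump
# service mark the "optimizer" phase without modifying user training code.
import torch
from torch.optim.optimizer import (
    register_optimizer_step_pre_hook,
    register_optimizer_step_post_hook,
)

status = {"phase": ""}  # stand-in for data_collector.optimizer_status

def pre_step(optimizer, args, kwargs):
    status["phase"] = "optimizer"      # entered before every optimizer.step()

def post_step(optimizer, args, kwargs):
    status["phase"] = "end_optimizer"  # entered after every optimizer.step()

pre_handle = register_optimizer_step_pre_hook(pre_step)
post_handle = register_optimizer_step_post_hook(post_step)

model = torch.nn.Linear(4, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
model(torch.randn(1, 4)).sum().backward()
opt.step()
assert status["phase"] == "end_optimizer"

# The hooks are process-global; remove them when the service stops.
pre_handle.remove()
post_handle.remove()
```

The patch follows the same pattern, and additionally wraps Megatron's `clip_grad_norm` (when importable) so that gradient-clipping APIs are tagged with a "clip_grad" parent node in the construct file instead of the regular module parent.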