From 63acab07c96ce110b66a8d9475c6c82ff16d5783 Mon Sep 17 00:00:00 2001
From: lcw
Date: Sat, 30 Nov 2024 09:54:59 +0800
Subject: [PATCH 1/2] Record optimizer status in the construct structure
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
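Reviewer notes (placed below the "---" line, so `git am` ignores them):

With level "mix", API calls fired from inside optimizer.step() used to be
parented in construct.json under whichever module node happened to be
current. This series brackets the step with the global optimizer hooks that
PyTorch 2.x exposes in torch/optim/optimizer.py (they do not exist before
2.0, so the unconditional import assumes a 2.x build), letting DataCollector
tag those APIs under an "optimizer" node instead. A minimal runnable sketch
of the bracketing pattern; the `status` dict is a hypothetical stand-in for
DataCollector.optimizer_status:

    import torch
    from torch.optim.optimizer import (
        register_optimizer_step_pre_hook,
        register_optimizer_step_post_hook,
    )

    status = {"phase": ""}  # stand-in for DataCollector.optimizer_status

    def pre_step(optimizer, args, kwargs):
        # Fires before every optimizer.step() in the process.
        status["phase"] = "optimizer"

    def post_step(optimizer, args, kwargs):
        # Fires after every optimizer.step().
        status["phase"] = "end_optimizer"

    register_optimizer_step_pre_hook(pre_step)
    register_optimizer_step_post_hook(post_step)

    param = torch.nn.Parameter(torch.ones(2))
    opt = torch.optim.SGD([param], lr=0.1)
    param.grad = torch.ones_like(param)
    opt.step()  # pre_step runs, then the update, then post_step
    print(status["phase"])  # -> "end_optimizer"

Anything dumped between the two hooks is recorded while optimizer_status is
"optimizer", which update_construct maps to an "optimizer" parent node.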
 .../msprobe/core/data_dump/data_collector.py  |  8 +++++++-
 .../accuracy_tools/msprobe/pytorch/service.py | 29 +++++++++++++++++++++++++++--
 2 files changed, 34 insertions(+), 3 deletions(-)

diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_collector.py b/debug/accuracy_tools/msprobe/core/data_dump/data_collector.py
index d21af4b82dc..559c1df73c2 100644
--- a/debug/accuracy_tools/msprobe/core/data_dump/data_collector.py
+++ b/debug/accuracy_tools/msprobe/core/data_dump/data_collector.py
@@ -39,6 +39,7 @@ class DataCollector:
         self.module_count = {}
         self.scope = ScopeFactory(self.config).build_scope()
         self.backward_module_names = {}
+        self.optimizer_status = ""
         atexit.register(self.write_json)
 
     @property
@@ -144,7 +145,12 @@ class DataCollector:
 
     def update_construct(self, name):
         if self.config.level not in DataCollector.level_without_construct:
-            self.data_writer.update_construct({name: self.module_processor.api_parent_node})
+            if self.optimizer_status == "clip_grad":
+                self.data_writer.update_construct({name: "clip_grad"})
+            elif self.optimizer_status == "optimizer":
+                self.data_writer.update_construct({name: "optimizer"})
+            else:
+                self.data_writer.update_construct({name: self.module_processor.api_parent_node})
         self.data_writer.update_construct(self.module_processor.module_node)
 
     def handle_data(self, name, data_info, flush=False):
diff --git a/debug/accuracy_tools/msprobe/pytorch/service.py b/debug/accuracy_tools/msprobe/pytorch/service.py
index 8fccd612e60..3344b1a25b5 100644
--- a/debug/accuracy_tools/msprobe/pytorch/service.py
+++ b/debug/accuracy_tools/msprobe/pytorch/service.py
@@ -32,6 +32,7 @@ from msprobe.pytorch.dump.kernel_dump.kernel_config import create_kernel_config_
 from msprobe.pytorch.hook_module.api_registry import api_register
 from msprobe.pytorch.hook_module.hook_module import HOOKModule
 from msprobe.pytorch.module_processer import ModuleProcesser
+from torch.optim.optimizer import register_optimizer_step_pre_hook, register_optimizer_step_post_hook
 
 torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
 if torch_version_above_or_equal_2:
@@ -323,6 +324,30 @@ class Service:
                                             self.config.online_run_ut)
             api_register.api_modularity()
 
+        if self.config.level == "mix":
+            def optimizer_pre_step_hook(optimizer, args, kwargs):
+                self.data_collector.optimizer_status = "optimizer"
+
+            def optimizer_post_step_hook(optimizer, args, kwargs):
+                self.data_collector.optimizer_status = "end_optimizer"
+
+            def patch_clip_grad(func):
+                def wrapper(*args, **kwargs):
+                    self.data_collector.optimizer_status = "clip_grad"
+                    result = func(*args, **kwargs)
+                    self.data_collector.optimizer_status = "end_clip_grad"
+                    return result
+                return wrapper
+
+            register_optimizer_step_pre_hook(optimizer_pre_step_hook)
+            register_optimizer_step_post_hook(optimizer_post_step_hook)
+            try:
+                from megatron.core.optimizer import MegatronOptimizer
+                MegatronOptimizer.clip_grad_norm = patch_clip_grad(MegatronOptimizer.clip_grad_norm)
+            except ImportError:
+                logger.warning_on_rank_0("Failed to patch the Megatron clip-grad function.")
+
+
     def attl_init(self):
         if self.config.online_run_ut:
             from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.attl import ATTLConfig, ATTL
@@ -353,7 +378,7 @@ class Service:
         elif self.attl.socket_manager is not None:
             logger.info(f"pid: {os.getpid()} finished, start send STOP signal.")
             self.attl.socket_manager.send_stop_signal()
-    
+
     def reset_status(self):
         ModuleProcesser.reset_module_stats()
         HOOKModule.reset_module_stats()
@@ -361,7 +386,7 @@ class Service:
 
         if self.config.level == Const.LEVEL_L2:
             self.data_collector.data_processor.reset_status()
-            return 
+            return
         if self.config.step and self.current_iter not in self.config.step:
             return
         if self.config.rank and self.current_rank not in self.config.rank:
--
Gitee

From 5024ce9719d477c0e1659c538cb4835a62f4ad43 Mon Sep 17 00:00:00 2001
From: lcw
Date: Sat, 30 Nov 2024 09:54:59 +0800
Subject: [PATCH 2/2] Record optimizer status in the construct structure
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
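Reviewer notes (placed below the "---" line, so `git am` ignores them):

Gradient clipping runs before the step itself, so it is bracketed separately
by monkey-patching MegatronOptimizer.clip_grad_norm. In recent megatron-core
releases that method returns the gradient norm Megatron uses for logging, so
the wrapper has to forward the return value. A runnable sketch of the
pattern, with a hypothetical DummyOptimizer standing in for Megatron and a
`status` dict standing in for DataCollector.optimizer_status:

    status = {"phase": ""}  # stand-in for DataCollector.optimizer_status

    class DummyOptimizer:
        def clip_grad_norm(self, clip_grad):
            return 1.23  # pretend gradient norm

    def patch_clip_grad(func):
        def wrapper(*args, **kwargs):
            status["phase"] = "clip_grad"    # tag dumps made during clipping
            result = func(*args, **kwargs)   # keep the norm for the caller
            status["phase"] = "end_clip_grad"
            return result
        return wrapper

    DummyOptimizer.clip_grad_norm = patch_clip_grad(DummyOptimizer.clip_grad_norm)

    opt = DummyOptimizer()
    assert opt.clip_grad_norm(1.0) == 1.23
    assert status["phase"] == "end_clip_grad"

With both brackets in place, a construct.json fragment could look like this
(the API names are illustrative only):

    {
        "Torch.norm.0.forward": "clip_grad",
        "Torch.add.0.forward": "optimizer"
    }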
= "end_clip_grad" + return wrapper + + register_optimizer_step_pre_hook(optimizer_pre_step_hook) + register_optimizer_step_post_hook(optimizer_post_step_hook) + try: + from megatron.core.optimizer import MegatronOptimizer + MegatronOptimizer.clip_grad_norm = patch_clip_grad(MegatronOptimizer.clip_grad_norm) + except ImportError: + logger.warning_on_rank_0("Fail to patch megatron clip grad function.") + + def attl_init(self): if self.config.online_run_ut: from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.attl import ATTLConfig, ATTL -- Gitee