From 09c3a57b7f29a6e44b3e16980af17780ebe0cd2c Mon Sep 17 00:00:00 2001
From: qianggee
Date: Thu, 21 Nov 2024 03:04:55 +0000
Subject: [PATCH] submodule only for xy

---
 .../msprobe/pytorch/monitor/module_hook.py | 23 ++---------------------
 1 file changed, 2 insertions(+), 21 deletions(-)

diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py b/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py
index e9807db84..d7d952dc7 100644
--- a/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py
+++ b/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py
@@ -19,12 +19,10 @@ import uuid
 import json
 from collections import defaultdict
 from functools import partial
-from copy import deepcopy
 from datetime import datetime
 
 import torch
 import torch.distributed as dist
-from torch.utils.hooks import BackwardHook
 from torch.optim.optimizer import register_optimizer_step_pre_hook, register_optimizer_step_post_hook
 
 from msprobe.pytorch.monitor.module_spec_verifier import validate_config_spec
@@ -354,25 +352,6 @@ class TrainerMon:
 
         logger.info_on_rank_0(f"> {hooked_count} out of {len(self.config['targets'])} are monitored.")
 
-        def clone_if_tensor(args):
-            if isinstance(args, tuple):
-                return tuple([clone_if_tensor(arg) for arg in args])
-            elif isinstance(args, torch.Tensor):
-                return args.clone()
-            else:
-                return args
-
-        @torch.no_grad
-        def wrap_hook_setup(setup):
-            def wrapped_setup(*args, **kwargs):
-                args = setup(*args, **kwargs)
-                args = clone_if_tensor(args)
-                return args
-
-            return wrapped_setup
-
-        BackwardHook.setup_output_hook = wrap_hook_setup(BackwardHook.setup_output_hook)
-
         if not self.optimizer_hooked:
             self.hook_optimizer()
         return
@@ -822,6 +801,8 @@ class TrainerMon:
         hooked_count = 0
         if self.xy_distribution or self.print_struct:
             for module_name, submodule in module.named_modules():
+                if submodule._modules:
+                    continue
                 name = self._is_target_module(module_name, target_names, vpp_stage)
                 if not name:
                     continue
-- 
Gitee
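
Note on the final hunk: the two added lines restrict the xy-distribution hooks to leaf modules, since named_modules() also yields container modules, i.e. modules whose _modules dict is non-empty, presumably so that a container and its children are not both hooked. Below is a minimal sketch of the same filtering idea; the toy model and the print statement are illustrative only and are not taken from module_hook.py.

    import torch.nn as nn

    # Toy model used only to illustrate leaf-module filtering; not part of msprobe.
    model = nn.Sequential(
        nn.Linear(4, 8),
        nn.Sequential(nn.ReLU(), nn.Linear(8, 2)),  # nested container module
    )

    for name, submodule in model.named_modules():
        # Container modules (the Sequential wrappers here) have children registered
        # in _modules; leaf modules (Linear, ReLU) have an empty _modules dict.
        if submodule._modules:
            continue
        # Only leaf modules reach this point, so an x/y hook would be registered
        # once per leaf rather than at every container level.
        print(f"would hook leaf module: {name} ({type(submodule).__name__})")

An equivalent check through the public API would be next(submodule.children(), None) is None; the patch reads the _modules dict directly.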