From 627235a89078cd76cd487f37e15162cbf8ea2912 Mon Sep 17 00:00:00 2001 From: RanZheng <364167184@qq.com> Date: Fri, 16 May 2025 16:31:32 +0800 Subject: [PATCH 1/2] =?UTF-8?q?=E3=80=90bugfix=E3=80=91=E4=BF=AE=E5=A4=8Dm?= =?UTF-8?q?onitor=E6=BF=80=E6=B4=BB=E5=80=BC=E9=87=87=E9=9B=86=E7=BC=BA?= =?UTF-8?q?=E5=A4=B1=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../msprobe/mindspore/monitor/common_func.py | 9 +++++++++ .../msprobe/mindspore/monitor/module_hook.py | 7 ++++++- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/debug/accuracy_tools/msprobe/mindspore/monitor/common_func.py b/debug/accuracy_tools/msprobe/mindspore/monitor/common_func.py index ef72a75ca24..133dffd0654 100644 --- a/debug/accuracy_tools/msprobe/mindspore/monitor/common_func.py +++ b/debug/accuracy_tools/msprobe/mindspore/monitor/common_func.py @@ -16,12 +16,21 @@ from mindspore import nn from mindspore import communication +from mindspore import Tensor from msprobe.mindspore.monitor.utils import logger from msprobe.mindspore.common.utils import is_mindtorch if is_mindtorch(): import torch +def check_tensor(tensor): + if tensor is None: + return Tensor([]) + elif tensor is not None and not isinstance(tensor, Tensor): + return Tensor([tensor]) + return tensor + + def is_valid_instance(model): return isinstance(model, torch.nn.Module) if is_mindtorch() else isinstance(model, nn.Cell) diff --git a/debug/accuracy_tools/msprobe/mindspore/monitor/module_hook.py b/debug/accuracy_tools/msprobe/mindspore/monitor/module_hook.py index 38e0f470c70..b3887f03100 100644 --- a/debug/accuracy_tools/msprobe/mindspore/monitor/module_hook.py +++ b/debug/accuracy_tools/msprobe/mindspore/monitor/module_hook.py @@ -28,7 +28,8 @@ from msprobe.core.common.log import logger from msprobe.core.common.const import MonitorConst from msprobe.core.common.file_utils import load_json, save_json from msprobe.mindspore.common.utils import is_mindtorch -from msprobe.mindspore.monitor.common_func import is_valid_instance, get_parameters, get_submodules, get_rank +from msprobe.mindspore.monitor.common_func import is_valid_instance, get_parameters, get_submodules, get_rank, \ + check_tensor from msprobe.mindspore.monitor.utils import get_summary_writer_tag_name, validate_config, step_accumulates_one, \ is_skip_step, get_metrics, get_single_metrics, get_target_output_dir from msprobe.mindspore.monitor.module_spec_verifier import validate_config_spec @@ -719,10 +720,12 @@ class TrainerMon: tbtag_tensor_map = {} if not context.ignore_in: cared_input = module_input if context.focused_in_col is None else module_input[context.focused_in_col] + cared_input = check_tensor(cared_input) tbtag_tensor_map.update( self.build_tbtag_tensor_map(f'{context.module_name}_{context.micro_step}', MonitorConst.ACTV_IN, cared_input)) cared_output = module_output if context.focused_out_col is None else module_output[context.focused_out_col] + cared_output = check_tensor(cared_output) tbtag_tensor_map.update( self.build_tbtag_tensor_map(f'{context.module_name}_{context.micro_step}', MonitorConst.ACTV_OUT, cared_output)) @@ -768,10 +771,12 @@ class TrainerMon: tbtag_tensor_map = {} if not context.ignore_in: cared_input_grad = input_grad if context.focused_in_col is None else input_grad[context.focused_in_col] + cared_input_grad = check_tensor(cared_input_grad) tbtag_tensor_map.update( self.build_tbtag_tensor_map( f'{context.module_name}_{context.micro_step}', MonitorConst.ACTVGRAD_IN, cared_input_grad)) cared_output_grad = output_grad if context.focused_out_col is None else output_grad[context.focused_out_col] + cared_output_grad = check_tensor(cared_output_grad) tbtag_tensor_map.update( self.build_tbtag_tensor_map(f'{context.module_name}_{context.micro_step}', MonitorConst.ACTVGRAD_OUT, cared_output_grad)) -- Gitee From 2acb0ededc81cee238f22d8e410a3de0bd3fbdd7 Mon Sep 17 00:00:00 2001 From: RanZheng <364167184@qq.com> Date: Fri, 16 May 2025 16:31:32 +0800 Subject: [PATCH 2/2] fix --- debug/accuracy_tools/msprobe/mindspore/monitor/common_func.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/debug/accuracy_tools/msprobe/mindspore/monitor/common_func.py b/debug/accuracy_tools/msprobe/mindspore/monitor/common_func.py index 133dffd0654..cde54f22f78 100644 --- a/debug/accuracy_tools/msprobe/mindspore/monitor/common_func.py +++ b/debug/accuracy_tools/msprobe/mindspore/monitor/common_func.py @@ -27,7 +27,7 @@ def check_tensor(tensor): if tensor is None: return Tensor([]) elif tensor is not None and not isinstance(tensor, Tensor): - return Tensor([tensor]) + return Tensor(tensor) return tensor -- Gitee