diff --git a/debug/accuracy_tools/msprobe/mindspore/monitor/common_func.py b/debug/accuracy_tools/msprobe/mindspore/monitor/common_func.py
index ef72a75ca246a8943bf580ba490465d2cca2c09b..cde54f22f78d35f0e9986e88827a1587814ccb7f 100644
--- a/debug/accuracy_tools/msprobe/mindspore/monitor/common_func.py
+++ b/debug/accuracy_tools/msprobe/mindspore/monitor/common_func.py
@@ -16,12 +16,22 @@
 from mindspore import nn
 from mindspore import communication
+from mindspore import Tensor
 from msprobe.mindspore.monitor.utils import logger
 from msprobe.mindspore.common.utils import is_mindtorch
 
 if is_mindtorch():
     import torch
 
 
+def check_tensor(tensor):
+    """Normalize a hook payload to a Tensor: None -> empty Tensor, non-Tensor -> Tensor(value)."""
+    if tensor is None:
+        return Tensor([])
+    if not isinstance(tensor, Tensor):
+        return Tensor(tensor)
+    return tensor
+
+
 def is_valid_instance(model):
     return isinstance(model, torch.nn.Module) if is_mindtorch() else isinstance(model, nn.Cell)
 
diff --git a/debug/accuracy_tools/msprobe/mindspore/monitor/module_hook.py b/debug/accuracy_tools/msprobe/mindspore/monitor/module_hook.py
index 38e0f470c706f919dfa8729feb89244a002466fd..b3887f031002b88fe5164d45b633c3a7ca93f2ac 100644
--- a/debug/accuracy_tools/msprobe/mindspore/monitor/module_hook.py
+++ b/debug/accuracy_tools/msprobe/mindspore/monitor/module_hook.py
@@ -28,7 +28,8 @@ from msprobe.core.common.log import logger
 from msprobe.core.common.const import MonitorConst
 from msprobe.core.common.file_utils import load_json, save_json
 from msprobe.mindspore.common.utils import is_mindtorch
-from msprobe.mindspore.monitor.common_func import is_valid_instance, get_parameters, get_submodules, get_rank
+from msprobe.mindspore.monitor.common_func import is_valid_instance, get_parameters, get_submodules, get_rank, \
+    check_tensor
 from msprobe.mindspore.monitor.utils import get_summary_writer_tag_name, validate_config, step_accumulates_one, \
     is_skip_step, get_metrics, get_single_metrics, get_target_output_dir
 from msprobe.mindspore.monitor.module_spec_verifier import validate_config_spec
@@ -719,10 +720,12 @@ class TrainerMon:
         tbtag_tensor_map = {}
         if not context.ignore_in:
             cared_input = module_input if context.focused_in_col is None else module_input[context.focused_in_col]
+            cared_input = check_tensor(cared_input)
             tbtag_tensor_map.update(
                 self.build_tbtag_tensor_map(f'{context.module_name}_{context.micro_step}', MonitorConst.ACTV_IN,
                                             cared_input))
         cared_output = module_output if context.focused_out_col is None else module_output[context.focused_out_col]
+        cared_output = check_tensor(cared_output)
         tbtag_tensor_map.update(
             self.build_tbtag_tensor_map(f'{context.module_name}_{context.micro_step}', MonitorConst.ACTV_OUT,
                                         cared_output))
@@ -768,10 +771,12 @@ class TrainerMon:
         tbtag_tensor_map = {}
         if not context.ignore_in:
             cared_input_grad = input_grad if context.focused_in_col is None else input_grad[context.focused_in_col]
+            cared_input_grad = check_tensor(cared_input_grad)
             tbtag_tensor_map.update(
                 self.build_tbtag_tensor_map(
                     f'{context.module_name}_{context.micro_step}', MonitorConst.ACTVGRAD_IN, cared_input_grad))
         cared_output_grad = output_grad if context.focused_out_col is None else output_grad[context.focused_out_col]
+        cared_output_grad = check_tensor(cared_output_grad)
         tbtag_tensor_map.update(
             self.build_tbtag_tensor_map(f'{context.module_name}_{context.micro_step}', MonitorConst.ACTVGRAD_OUT,
                                         cared_output_grad))