diff --git a/debug/accuracy_tools/msprobe/core/common/const.py b/debug/accuracy_tools/msprobe/core/common/const.py
index e832348fb7957f8d7dc624faa04697ee78615470..1f2a652b27b7e5b58db1bb3b5e8d9b0addc50ae1 100644
--- a/debug/accuracy_tools/msprobe/core/common/const.py
+++ b/debug/accuracy_tools/msprobe/core/common/const.py
@@ -87,6 +87,8 @@ class Const:
     INPUT_KWARGS = 'input_kwargs'
     GRAD_INPUT = 'grad_input'
     GRAD_OUTPUT = 'grad_output'
+    PARAMS = 'parameters'
+    PARAMS_GRAD = 'parameters_grad'
     START = "start"
     STOP = "stop"
     ENV_ENABLE = "1"
@@ -125,6 +127,7 @@ class Const:
     DISTRIBUTED = 'Distributed'
     DUMP_PREFIX = ["Distributed", "Functional", "Torch", "Tensor", "Mint", "MintFunctional", "Primitive", "Aten", "VF",
                    "NPU", "Jit"]
+    MODULE_PREFIX = ["Module", "Cell"]

     # struct json param
     ORIGIN_DATA = "origin_data"
diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_collector.py b/debug/accuracy_tools/msprobe/core/data_dump/data_collector.py
index dc7ae7bcddccebc4c310d5bc466502abe5e8d22d..eac5110c911e944e46bd90b592e8dc49eab264cb 100644
--- a/debug/accuracy_tools/msprobe/core/data_dump/data_collector.py
+++ b/debug/accuracy_tools/msprobe/core/data_dump/data_collector.py
@@ -38,6 +38,7 @@ class DataCollector:
         self.module_processor = DataProcessorFactory.get_module_processor(self.config.framework)
         self.module_count = {}
         self.scope = ScopeFactory(self.config).build_scope()
+        self.backward_module_names = {}
         atexit.register(self.write_json)

     @property
@@ -116,6 +117,12 @@ class DataCollector:
         data_info = self.data_processor.analyze_backward(name, module, module_input_output)
         if self.config.level == Const.LEVEL_L2:
             return
+
+        # Record the name of the module that executed a backward pass
+        if data_info and name.split(Const.SEP)[0] in Const.MODULE_PREFIX:
+            module_name = name.rsplit(Const.SEP, 2)[0]
+            # Add the module name to the backward-module set, used during gradient collection to decide whether gradients should be collected
+            self.backward_module_names[module_name] = True
         self.handle_data(name, data_info, flush=self.data_processor.is_terminated)

     def backward_input_data_collect(self, name, module, pid, module_input_output):
@@ -153,3 +160,13 @@ class DataCollector:

     def update_iter(self, current_iter):
         self.data_processor.update_iter(current_iter)
+
+    def params_data_collect(self, name, param_name, pid, data):
+        grad_name = name + Const.SEP + Const.PARAMS_GRAD
+        # Check the scope and pid, and whether a backward pass was run for the current name
+        if not self.check_scope_and_pid(self.scope, name, pid) and not self.backward_module_names.get(name):
+            # If no backward pass was run, remove the placeholder grad data written earlier
+            self.data_writer.cache_data.get("data").pop(grad_name, None)
+            return
+        data_info = self.data_processor.analyze_params(grad_name, param_name, data)
+        self.handle_data(grad_name, data_info, flush=self.data_processor.is_terminated)
diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py
index 07bc4d6dbc085f3e1d9c94d6a480696e84cd40b3..4b540a9d947948acc24e9f85ae1c673e644a06fb 100644
--- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py
+++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py
@@ -102,6 +102,7 @@ class BaseDataProcessor:
         self.current_iter = 0
         self._return_forward_new_output = False
         self._forward_new_output = None
+        self.save_name = None
         if hasattr(config, "data_mode"):
             self.allowed_data_mode = self._get_allowed_data_mode(config.data_mode)

@@ -297,6 +298,11 @@ class BaseDataProcessor:
             self.api_data_category = Const.OUTPUT
             output_info_list = self.analyze_element(module_input_output.output_tuple)
             api_info_struct[name][Const.OUTPUT] = output_info_list
+
+        if name in api_info_struct and hasattr(module_input_output, Const.PARAMS):
+            self.api_data_category = Const.PARAMS
+            api_info_struct[name][Const.PARAMS] = self.analyze_element(getattr(module_input_output, Const.PARAMS))
+
         return api_info_struct

     def analyze_pre_forward_inplace(self, name, module_input_output: ModuleForwardInputsOutputs):
@@ -359,9 +365,21 @@ class BaseDataProcessor:
             api_info_struct[name][Const.OUTPUT] = output_info_list
         return api_info_struct

+    def analyze_params(self, name, param_name, grad):
+        api_info_struct = {}
+        self.save_name = name + Const.SEP + param_name
+        data_info = self.analyze_element(grad)
+        grad_info_dict = {param_name: [data_info]}
+        api_info_struct[name] = grad_info_dict
+        return api_info_struct
+
     def get_save_file_path(self, suffix):
         file_format = Const.PT_SUFFIX if self.config.framework == Const.PT_FRAMEWORK else Const.NUMPY_SUFFIX
-        dump_data_name = (self.current_api_or_module_name + Const.SEP + self.api_data_category + Const.SEP +
-                          suffix + file_format)
+        if self.save_name is not None:
+            dump_data_name = (self.save_name + file_format)
+            self.save_name = None
+        else:
+            dump_data_name = (self.current_api_or_module_name + Const.SEP + self.api_data_category + Const.SEP +
+                              suffix + file_format)
         file_path = os.path.join(self.data_writer.dump_tensor_data_dir, dump_data_name)
         return dump_data_name, file_path
diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py
index 8a8ae763714869442907d38e3b913f6245659bed..a980d6ac80fcb0626547beffa6e8d1c655403e8c 100644
--- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py
+++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py
@@ -161,6 +161,12 @@ class OverflowCheckDataProcessor(MindsporeDataProcessor):
         api_info_struct = super().analyze_backward(name, module, module_input_output)
         self.maybe_save_overflow_data()
         return api_info_struct if self.has_overflow else None
+
+    def analyze_params(self, name, param_name, grad):
+        self.has_overflow = False
+        api_info_struct = super().analyze_params(name, param_name, grad)
+        self.maybe_save_overflow_data()
+        return api_info_struct if self.has_overflow else None

     def maybe_save_overflow_data(self):
         if self.has_overflow:
diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py
index 0130b7eda12851e2e4782ce9e6a5f0cf87fc8384..71c771805d249e8db216bd7a63ea18a26946489a 100644
--- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py
+++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py
@@ -266,6 +266,13 @@ class OverflowCheckDataProcessor(PytorchDataProcessor):
         api_info_struct = super().analyze_backward(name, module, module_input_output)
         self.handle_overflow()
         return api_info_struct if self.has_overflow else None
+
+    def analyze_params(self, name, param_name, grad):
+        self.has_overflow = False
+        self._is_support_inf_nan()
+        api_info_struct = super().analyze_params(name, param_name, grad)
+        self.handle_overflow()
+        return api_info_struct if self.has_overflow else None

     def handle_overflow(self):
         if not self.support_inf_nan:
diff --git a/debug/accuracy_tools/msprobe/mindspore/common/utils.py b/debug/accuracy_tools/msprobe/mindspore/common/utils.py
index 22038f351b76130cd57b7768c820040e4f19ba21..28d3f9230a097388250e6e17526e1dd1e5693487 100644
--- a/debug/accuracy_tools/msprobe/mindspore/common/utils.py
+++ b/debug/accuracy_tools/msprobe/mindspore/common/utils.py
@@ -54,6 +54,11 @@ def convert_to_int(value):
         return int(value)
     except Exception:
         return -1
+
+
+def clean_input_kwargs(cell):
+    if hasattr(cell, 'input_kwargs'):
+        del cell.input_kwargs


 def list_lowest_level_directories(root_dir):
diff --git a/debug/accuracy_tools/msprobe/mindspore/service.py b/debug/accuracy_tools/msprobe/mindspore/service.py
index 9217bd3efeb1436f33ad57661e2f7ac439660906..3d38e9c7e76cf9cff4eabd94cd6174200e57f368 100644
--- a/debug/accuracy_tools/msprobe/mindspore/service.py
+++ b/debug/accuracy_tools/msprobe/mindspore/service.py
@@ -36,7 +36,7 @@ from msprobe.core.data_dump.data_processor.base import ModuleBackwardInputsOutpu
 from msprobe.core.data_dump.scope import BaseScope
 from msprobe.mindspore.cell_processor import CellProcessor
 from msprobe.mindspore.common.log import logger
-from msprobe.mindspore.common.utils import get_rank_if_initialized
+from msprobe.mindspore.common.utils import get_rank_if_initialized, clean_input_kwargs
 from msprobe.mindspore.dump.hook_cell.api_registry import api_register
 from msprobe.mindspore.dump.hook_cell.primitive_hooks import PrimitiveHookService
 from msprobe.mindspore.dump.jit_dump import JitDump
@@ -67,6 +67,14 @@ class Service:
             raise MsprobeException(
                 MsprobeException.INVALID_PARAM_ERROR, "model 参数必须是 mindspore.nn.Cell 类型。"
             )
+
+    @staticmethod
+    def prepare_module_input_output(target_type, cell, input_data, output):
+        if target_type == BaseScope.Module_Type_Module:
+            module_input_output = ModuleForwardInputsOutputs(args=input_data, kwargs={}, output=output)
+        else:
+            module_input_output = ModuleForwardInputsOutputs(args=input_data, kwargs=cell.input_kwargs, output=output)
+        return module_input_output

     def check_level_valid(self):
         if self.config.level == Const.LEVEL_L2:
@@ -74,26 +82,52 @@
                 MsprobeException.INVALID_PARAM_ERROR, "L2 level dump function is currently not supported."
             )

-    def build_hook(self, target_type, name):
+    def build_hook(self, target_type, name):
+        def grad_hook(ori_name, param_name):
+            def hook_fn(grad):
+                if not self.should_excute_hook():
+                    return None
+                self.data_collector.params_data_collect(ori_name, param_name, pid, grad)
+                return None
+            return hook_fn
+
+        def register_param_hook(cell_name, cell, params_dict):
+            # Do not register parameter hooks when data_mode is forward only
+            if not (Const.FORWARD in self.config.data_mode and Const.BACKWARD not in self.config.data_mode):
+                # Check whether hooks have already been registered for the parameters
+                if params_dict and hasattr(cell, 'has_param_hook') and not cell.has_param_hook:
+                    ori_name = cell_name.rsplit(Const.SEP, 2)[0]
+                    grad_name = ori_name + Const.SEP + Const.PARAMS_GRAD
+                    # Initialize the data_info for grad_name when registering the hooks
+                    data_info = {grad_name: {key: [None] for key in params_dict}}
+                    # Write the data_info for grad_name into cache_data first; it is updated after the gradients are computed
+                    self.data_collector.handle_data(grad_name, data_info, flush=self.data_collector.data_processor.is_terminated)
+                    for param_name, param in params_dict.items():
+                        param.register_hook(grad_hook(ori_name, param_name))
+                    cell.has_param_hook = True
+
         def forward_hook(api_or_cell_name, cell, input_data, output):
             if not self.should_excute_hook():
-                if hasattr(cell, 'input_kwargs'):
-                    del cell.input_kwargs
+                clean_input_kwargs(cell)
                 return None
-
+
+            module_input_output = self.prepare_module_input_output(target_type, cell, input_data, output)
+            params_dict = {}
             if target_type == BaseScope.Module_Type_Module:
                 api_or_cell_name = self.cell_processor.set_and_get_reserved_name(cell, api_or_cell_name)
-                module_input_output = ModuleForwardInputsOutputs(args=input_data, kwargs={}, output=output)
-            else:
-                module_input_output = ModuleForwardInputsOutputs(args=input_data, kwargs=cell.input_kwargs,
-                                                                 output=output)
+                params_dict = {key.split(Const.SEP)[-1]: value for key, value in cell.parameters_dict(recurse=False).items()}
+                setattr(module_input_output, Const.PARAMS, params_dict)
+                # Set the has_param_hook attribute to avoid registering hooks repeatedly
+                if not hasattr(cell, 'has_param_hook'):
+                    setattr(cell, 'has_param_hook', False)

             self.data_collector.update_api_or_module_name(api_or_cell_name)
             self.data_collector.forward_data_collect(api_or_cell_name, cell, pid, module_input_output)
+            register_param_hook(api_or_cell_name, cell, params_dict)
+
             if self.data_collector.if_return_forward_new_output():
                 return self.data_collector.get_forward_new_output()
-            if hasattr(cell, 'input_kwargs'):
-                del cell.input_kwargs
+            clean_input_kwargs(cell)
             return output

         def backward_hook(api_or_cell_name, cell, grad_input, grad_output):
diff --git a/debug/accuracy_tools/msprobe/pytorch/service.py b/debug/accuracy_tools/msprobe/pytorch/service.py
index fcc6282207e9f12ecddcccd8131755093b7c12d9..95c2c5867c231a472b424bf0b7dda0e26a9d776f 100644
--- a/debug/accuracy_tools/msprobe/pytorch/service.py
+++ b/debug/accuracy_tools/msprobe/pytorch/service.py
@@ -89,12 +89,43 @@ class Service:
             self.data_collector.pre_forward_data_collect(api_or_module_name, module, pid, module_input_output)
             return args, kwargs

+        def grad_hook(ori_name, param_name):
+            def hook_fn(grad):
+                if not self.should_execute_hook():
+                    return grad
+                self.data_collector.params_data_collect(ori_name, param_name, pid, grad)
+                return grad
+            return hook_fn
+
+        def register_param_hook(module_name, module, params_dict):
+            # Do not register parameter hooks when data_mode is forward only
+            if not (Const.FORWARD in self.config.data_mode and Const.BACKWARD not in self.config.data_mode):
+                # Check whether hooks have already been registered for the parameters
+                if params_dict and hasattr(module, 'has_param_hook') and not module.has_param_hook:
+                    ori_name = module_name.rsplit(Const.SEP, 2)[0]
+                    grad_name = ori_name + Const.SEP + Const.PARAMS_GRAD
+                    # Initialize the data_info for grad_name when registering the hooks
+                    data_info = {grad_name: {key: [None] for key in params_dict}}
+                    # Write the data_info for grad_name into cache_data first; it is updated after the gradients are computed
+                    self.data_collector.handle_data(grad_name, data_info, flush=self.data_collector.data_processor.is_terminated)
+                    for param_name, param in params_dict.items():
+                        param.register_hook(grad_hook(ori_name, param_name))
+                    module.has_param_hook = True
+
         def forward_hook(api_or_module_name, module, args, kwargs, output):
             if not self.should_execute_hook():
                 return None

+            module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=output)
+            params_dict = {}
             if module_type == BaseScope.Module_Type_Module:
                 api_or_module_name = module.mindstudio_reserved_name
+                params_dict = {key.split(Const.SEP)[-1]: value for key, value in module.named_parameters(recurse=False)}
+                setattr(module_input_output, Const.PARAMS, params_dict)
+                # Set the has_param_hook attribute to avoid registering hooks repeatedly
+                if not hasattr(module, 'has_param_hook'):
+                    setattr(module, 'has_param_hook', False)
+
             self.data_collector.update_api_or_module_name(api_or_module_name)

             if self.config.online_run_ut:
@@ -103,12 +134,12 @@
                 api_data = ApiData(name[:-1], args, kwargs, output, self.current_iter, self.current_rank)
                 self.attl_send(api_data)
                 return None
+
+            self.data_collector.forward_data_collect(api_or_module_name, module, pid, module_input_output)
+            register_param_hook(api_or_module_name, module, params_dict)

-            if self.data_collector:
-                module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=output)
-                self.data_collector.forward_data_collect(api_or_module_name, module, pid, module_input_output)
-                if self.data_collector.if_return_forward_new_output():
-                    return self.data_collector.get_forward_new_output()
+            if self.data_collector.if_return_forward_new_output():
+                return self.data_collector.get_forward_new_output()
             return output

         def forward_hook_torch_version_below_2(api_or_module_name, module, args, output):
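Note: both services above rely on per-parameter gradient hooks registered from inside the module forward hook. The snippet below is a minimal, standalone sketch of that mechanism in PyTorch; it is not part of the patch, and the names `collected` and `make_grad_hook` are illustrative only. It shows how `register_hook` on each of a module's own parameters delivers the gradient once backward runs, which is the same pattern `register_param_hook` uses to feed `params_data_collect`.

# Standalone sketch (not part of the patch): per-parameter gradient hooks in PyTorch.
# The patch registers one hook per parameter inside the module forward hook and routes the
# gradient into DataCollector.params_data_collect; here the gradients just land in a dict.
import torch
import torch.nn as nn

collected = {}  # stands in for the dump cache: parameter name -> gradient tensor


def make_grad_hook(param_name):
    def hook_fn(grad):
        collected[param_name] = grad.detach().clone()
        return grad  # return the gradient unchanged so autograd behaviour is unaffected
    return hook_fn


module = nn.Linear(4, 2)
# recurse=False mirrors the patch: only the module's own parameters, not its children's
for name, param in module.named_parameters(recurse=False):
    param.register_hook(make_grad_hook(name))

loss = module(torch.randn(3, 4)).sum()
loss.backward()
print(sorted(collected))  # ['bias', 'weight'] after the backward pass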