diff --git a/debug/accuracy_tools/msprobe/docs/06.data_dump_MindSpore.md b/debug/accuracy_tools/msprobe/docs/06.data_dump_MindSpore.md index fa0a3e891a68c93d4f7add65e6d628d002198f27..9a5e2154c15c8a71db804823d33cb522b9658c0f 100644 --- a/debug/accuracy_tools/msprobe/docs/06.data_dump_MindSpore.md +++ b/debug/accuracy_tools/msprobe/docs/06.data_dump_MindSpore.md @@ -28,6 +28,8 @@ msprobe 工具通过在训练脚本中添加 `PrecisionDebugger` 接口并启动 dump "statistics"模式的性能膨胀大小"与"tensor"模式采集的数据量大小,可以参考[dump基线](data_dump_MindSpore/data_dump_MindSpore_baseline.md)。 +**注意**:因 MindSpore 框架自动微分机制的限制,dump 数据中可能会缺少原地操作模块/API 及其上一个模块/API 的反向数据。 + ## 5. 场景介绍 ### 5.1 静态图场景 diff --git a/debug/accuracy_tools/msprobe/docs/29.data_dump_MSAdapter.md b/debug/accuracy_tools/msprobe/docs/29.data_dump_MSAdapter.md index a290d349de1daadb881993e67bfd45a367bd8400..6439a0a6eb1e5a12829c7f10ab0aa1baccac85d0 100644 --- a/debug/accuracy_tools/msprobe/docs/29.data_dump_MSAdapter.md +++ b/debug/accuracy_tools/msprobe/docs/29.data_dump_MSAdapter.md @@ -4,7 +4,11 @@ MSAdapter 是一款 MindSpore 生态适配工具,可以将 PyTorch 训练脚 msprobe 工具主要通过在训练脚本内添加 dump 接口、启动训练的方式采集精度数据。 -**注意**:为了正确识别 MSAdapter 场景,在导入 msprobe 工具前,需完成 torch 模块的的导入。 +**注意**: + +- 为了正确识别 MSAdapter 场景,在导入 msprobe 工具前,需完成 torch 模块的导入。 + +- 因 MindSpore 框架自动微分机制的限制,dump 数据中可能会缺少原地操作模块/API 及其上一个模块/API 的反向数据。 本工具提供固定的 API 支持列表,若需要删除或增加 dump 的 API,可以在 msprobe/pytorch/hook_module/support_wrap_ops.yaml 文件内手动修改,如下示例: diff --git a/debug/accuracy_tools/msprobe/mindspore/cell_processor.py b/debug/accuracy_tools/msprobe/mindspore/cell_processor.py index 468402e6e2855b0b3d52e7e7997fc431ff966902..d96e8c0a590e1a8a16ab6cfd1ac60d5e51592df4 100644 --- a/debug/accuracy_tools/msprobe/mindspore/cell_processor.py +++ b/debug/accuracy_tools/msprobe/mindspore/cell_processor.py @@ -31,7 +31,8 @@ from msprobe.mindspore.common.utils import ( is_mindtorch, get_cells_and_names_with_index, has_kwargs_in_forward_hook, - is_graph_mode_cell_dump_allowed + is_graph_mode_cell_dump_allowed, + is_backward_hook_output_a_view ) from msprobe.mindspore.debugger.debugger_config import DebuggerConfig from msprobe.mindspore.dump.graph_mode_cell_dump import GraphModeCellDump @@ -174,7 +175,7 @@ class CellProcessor: bw_hook.register_backward_hook() CellProcessor.cell_bw_hook_kernels[full_forward_name] = bw_hook - args = bw_hook(*args) + args = bw_hook(args) if is_backward_hook_output_a_view() else bw_hook(*args) return args @@ -199,12 +200,15 @@ class CellProcessor: logger.warning("For backward hooks to be called," " cell output should be a Tensor or a tuple of Tensors" f" but received {type(outputs)}") - if isinstance(outputs, tuple): - new_outputs = bw_hook(*outputs) - else: + if is_backward_hook_output_a_view(): new_outputs = bw_hook(outputs) - if isinstance(outputs, tuple) and len(outputs) == 1: - new_outputs = (new_outputs,) + else: + if isinstance(outputs, tuple): + new_outputs = bw_hook(*outputs) + else: + new_outputs = bw_hook(outputs) + if isinstance(outputs, tuple) and len(outputs) == 1: + new_outputs = (new_outputs,) outputs = new_outputs def get_backward_pre_hook(full_backward_name, backward_data_hook): @@ -227,18 +231,21 @@ class CellProcessor: self.cell_backward_pre_hook[-1]) bw_pre_hook.register_backward_pre_hook() - if isinstance(outputs, tuple): - result = bw_pre_hook(*outputs) - else: + if is_backward_hook_output_a_view(): result = bw_pre_hook(outputs) - if isinstance(outputs, tuple): - if len(outputs) == 1: - result = (result,) - if len(result) != len(outputs): - raise TypeError( - f"The backward pre hook return value size is {len(result)} " - f"not equal to output size {len(outputs)}" - ) + else: + if isinstance(outputs, tuple): + result = bw_pre_hook(*outputs) + else: + result = bw_pre_hook(outputs) + if isinstance(outputs, tuple): + if len(outputs) == 1: + result = (result,) + if len(result) != len(outputs): + raise TypeError( + f"The backward pre hook return value size is {len(result)} " + f"not equal to output size {len(outputs)}" + ) return result return forward_pre_hook diff --git a/debug/accuracy_tools/msprobe/mindspore/common/utils.py b/debug/accuracy_tools/msprobe/mindspore/common/utils.py index e4646e906d0fb66e697e3627e80d1580e902b7d9..386bd90840792ca941f4b6769070b71b263cfcd6 100644 --- a/debug/accuracy_tools/msprobe/mindspore/common/utils.py +++ b/debug/accuracy_tools/msprobe/mindspore/common/utils.py @@ -42,6 +42,7 @@ else: mindtorch_check_result = None register_backward_hook_functions = {} kwargs_exist_in_forward_hook = None +is_output_of_backward_hook_a_view = None class MsprobeStep(ms.train.Callback): @@ -329,3 +330,43 @@ def has_kwargs_in_forward_hook(): return kwargs_exist_in_forward_hook return kwargs_exist_in_forward_hook + + +def is_backward_hook_output_a_view(): + global is_output_of_backward_hook_a_view + + if is_output_of_backward_hook_a_view is None: + is_output_of_backward_hook_a_view = False + if getattr(ms, '__version__', '2.4.0') < '2.7.0': + return is_output_of_backward_hook_a_view + try: + from mindspore.ops.operations import _inner_ops as inner + call_func = getattr(inner.CellBackwardHook, '__call__') + func_params = inspect.signature(call_func).parameters + except Exception: + return is_output_of_backward_hook_a_view + if 'args' in func_params and func_params['args'].kind == inspect.Parameter.POSITIONAL_OR_KEYWORD: + is_output_of_backward_hook_a_view = True + + return is_output_of_backward_hook_a_view + + +def wrap_backward_hook_call_func(call_func): + if not is_backward_hook_output_a_view(): + return call_func + + from mindspore.common.api import _pynative_executor as executor + from mindspore._c_expression import CreationType + + def new_call(self, args): + outputs = call_func(self, args) + if isinstance(outputs, ms.Tensor): + executor.set_creation_type(outputs, CreationType.DEFAULT) + elif isinstance(outputs, tuple): + for item in outputs: + if isinstance(item, ms.Tensor): + executor.set_creation_type(item, CreationType.DEFAULT) + return outputs + new_call.__name__ = '__call__' + + return new_call diff --git a/debug/accuracy_tools/msprobe/mindspore/debugger/precision_debugger.py b/debug/accuracy_tools/msprobe/mindspore/debugger/precision_debugger.py index 420f15470c4d1ac0d58e64f5e30c49fed0841ee7..b5b83e14b25359bf6774dbd91c063b4454e1e756 100644 --- a/debug/accuracy_tools/msprobe/mindspore/debugger/precision_debugger.py +++ b/debug/accuracy_tools/msprobe/mindspore/debugger/precision_debugger.py @@ -17,6 +17,7 @@ import os from collections import defaultdict, namedtuple import mindspore as ms +from mindspore.ops.operations import _inner_ops as inner from mindspore._c_expression import MSContext from msprobe.core.common.const import Const, MsgConst @@ -28,7 +29,8 @@ from msprobe.mindspore.common.const import Const as MsConst from msprobe.mindspore.common.utils import ( set_register_backward_hook_functions, check_save_param, - is_graph_mode_cell_dump_allowed + is_graph_mode_cell_dump_allowed, + wrap_backward_hook_call_func ) from msprobe.mindspore.debugger.debugger_config import DebuggerConfig from msprobe.mindspore.dump.graph_mode_cell_dump import GraphModeCellDump @@ -77,6 +79,9 @@ class PrecisionDebugger(BasePrecisionDebugger): self.common_config.dump_path = dump_path if dump_path else self.common_config.dump_path self.config = DebuggerConfig(self.common_config, self.task_config) + setattr(inner.CellBackwardHook, '__call__', + wrap_backward_hook_call_func(getattr(inner.CellBackwardHook, '__call__'))) + if self._is_kernel_dump() and _msprobe_c: os.environ["MS_HOOK_ENABLE"] = "on" _msprobe_c._PrecisionDebugger(framework="MindSpore", config_path=config_path) diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_register.py b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_register.py index ddee76596452323bf564270f526de9103a6c1a8d..d82ec093725f815f368cef52d35067411a67795b 100644 --- a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_register.py +++ b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_register.py @@ -69,7 +69,7 @@ else: } _supported_api_list_path = (os.path.join(cur_path, '../../../pytorch/hook_module', MsConst.SUPPORTED_API_LIST_FILE),) - _backlist = [f'{Const.PT_API_TYPE_TENSOR}.__setitem__'] + _backlist = [] _inner_used_api = { Const.MS_FRAMEWORK + Const.SEP + Const.MS_API_TYPE_OPS: ( diff --git a/debug/accuracy_tools/msprobe/mindspore/mindtorch/mindtorch_adaptor.py b/debug/accuracy_tools/msprobe/mindspore/mindtorch/mindtorch_adaptor.py index 27e42d52ba6190ec7e7531af25464e6aa3996b2b..7ca256dbcf8f9011f5ca84898f643888ea7f890e 100644 --- a/debug/accuracy_tools/msprobe/mindspore/mindtorch/mindtorch_adaptor.py +++ b/debug/accuracy_tools/msprobe/mindspore/mindtorch/mindtorch_adaptor.py @@ -93,6 +93,8 @@ from torch.nn.modules.module import (_global_backward_pre_hooks, _global_backwar _global_forward_hooks, _global_forward_hooks_always_called) from torch.utils.hooks import RemovableHandle +from msprobe.mindspore.common.utils import is_backward_hook_output_a_view + def _call_impl(self, *args, **kwargs): forward_call = self.forward @@ -245,11 +247,14 @@ def _get_backward_hooks(self): def apply_backward_hook_on_tensors(cell_backward_hook, args): - is_tuple = True - if not isinstance(args, tuple): - args = (args,) - is_tuple = False - hooked_args = cell_backward_hook(*args) - if is_tuple and len(args) == 1: - hooked_args = (hooked_args, ) + if is_backward_hook_output_a_view(): + hooked_args = cell_backward_hook(args) + else: + is_tuple = True + if not isinstance(args, tuple): + args = (args,) + is_tuple = False + hooked_args = cell_backward_hook(*args) + if is_tuple and len(args) == 1: + hooked_args = (hooked_args, ) return hooked_args diff --git a/debug/accuracy_tools/msprobe/test/common_set_up/test_set_up.py b/debug/accuracy_tools/msprobe/test/common_set_up/test_set_up.py index 0be8de96890da80a503ae68ac4f166b018adf9ec..474a2168fbe09387f605d839faa190fc1dcf17dc 100644 --- a/debug/accuracy_tools/msprobe/test/common_set_up/test_set_up.py +++ b/debug/accuracy_tools/msprobe/test/common_set_up/test_set_up.py @@ -44,9 +44,11 @@ importlib.reload(mindspore_service) importlib.reload(common_func) reset_torch_tensor() + def register_backward_pre_hook(*args, **kwargs): pass + register_backward_hook_functions['full'] = ms.nn.Cell.register_backward_hook register_backward_hook_functions["pre"] = register_backward_pre_hook