diff --git a/debug/accuracy_tools/msprobe/docs/06.data_dump_MindSpore.md b/debug/accuracy_tools/msprobe/docs/06.data_dump_MindSpore.md
index ab4533a879e484ab05780dff3b9252b61041a02b..72e8ed4d9068f93715605b10387a03e497259ce0 100644
--- a/debug/accuracy_tools/msprobe/docs/06.data_dump_MindSpore.md
+++ b/debug/accuracy_tools/msprobe/docs/06.data_dump_MindSpore.md
@@ -28,6 +28,8 @@ The msprobe tool collects data by adding the `PrecisionDebugger` interface to the training script and launching
 
 For the performance overhead of the "statistics" dump mode and the volume of data collected in "tensor" mode, see the [dump baselines](data_dump_MindSpore/data_dump_MindSpore_baseline.md).
 
+**Note**: Due to limitations of the MindSpore framework's automatic differentiation mechanism, the dumped data may lack the backward data of in-place modules/APIs and of the module/API immediately preceding them.
+
 ## 5. Scenarios
 
 ### 5.1 Static Graph Scenario
diff --git a/debug/accuracy_tools/msprobe/docs/29.data_dump_MSAdapter.md b/debug/accuracy_tools/msprobe/docs/29.data_dump_MSAdapter.md
index a290d349de1daadb881993e67bfd45a367bd8400..6439a0a6eb1e5a12829c7f10ab0aa1baccac85d0 100644
--- a/debug/accuracy_tools/msprobe/docs/29.data_dump_MSAdapter.md
+++ b/debug/accuracy_tools/msprobe/docs/29.data_dump_MSAdapter.md
@@ -4,7 +4,11 @@ MSAdapter is a MindSpore ecosystem adaptation tool that lets PyTorch training scripts
 
 The msprobe tool collects precision data mainly by adding dump interfaces to the training script and then launching training.
 
-**Note**: To correctly identify the MSAdapter scenario, the torch module must be imported before importing the msprobe tool.
+**Note**:
+
+- To correctly identify the MSAdapter scenario, the torch module must be imported before importing the msprobe tool.
+
+- Due to limitations of the MindSpore framework's automatic differentiation mechanism, the dumped data may lack the backward data of in-place modules/APIs and of the module/API immediately preceding them.
 
 The tool ships with a fixed list of supported APIs. To remove or add APIs to be dumped, edit the msprobe/pytorch/hook_module/support_wrap_ops.yaml file manually, for example:
 
diff --git a/debug/accuracy_tools/msprobe/mindspore/cell_processor.py b/debug/accuracy_tools/msprobe/mindspore/cell_processor.py
index cc90cb03e0e0377c6ea58e9ba9be60439d004777..71c8d72461e0e3e27952cd77e61ab1c24dfde8cd 100644
--- a/debug/accuracy_tools/msprobe/mindspore/cell_processor.py
+++ b/debug/accuracy_tools/msprobe/mindspore/cell_processor.py
@@ -29,7 +29,8 @@ from msprobe.mindspore.common.utils import (
     is_mindtorch,
     get_cells_and_names_with_index,
     has_kwargs_in_forward_hook,
-    is_graph_mode_cell_dump_allowed
+    is_graph_mode_cell_dump_allowed,
+    is_backward_hook_output_a_view
 )
 from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
 from msprobe.mindspore.dump.graph_mode_cell_dump import GraphModeCellDump
@@ -166,7 +167,7 @@ class CellProcessor:
                 bw_hook.register_backward_hook()
                 CellProcessor.cell_bw_hook_kernels[full_forward_name] = bw_hook
 
-                args = bw_hook(*args)
+                args = bw_hook(args) if is_backward_hook_output_a_view() else bw_hook(*args)
 
             return args
@@ -190,12 +191,15 @@ class CellProcessor:
                     logger.warning("For backward hooks to be called,"
                                    " cell output should be a Tensor or a tuple of Tensors"
                                    f" but received {type(outputs)}")
-                if isinstance(outputs, tuple):
-                    new_outputs = bw_hook(*outputs)
-                else:
+                if is_backward_hook_output_a_view():
                     new_outputs = bw_hook(outputs)
-                if isinstance(outputs, tuple) and len(outputs) == 1:
-                    new_outputs = (new_outputs,)
+                else:
+                    if isinstance(outputs, tuple):
+                        new_outputs = bw_hook(*outputs)
+                    else:
+                        new_outputs = bw_hook(outputs)
+                    if isinstance(outputs, tuple) and len(outputs) == 1:
+                        new_outputs = (new_outputs,)
                 outputs = new_outputs
 
         def get_backward_pre_hook(full_backward_name, backward_data_hook):
@@ -216,18 +220,21 @@ class CellProcessor:
                                                       self.cell_backward_pre_hook[-1])
                     bw_pre_hook.register_backward_pre_hook()
 
-                if isinstance(outputs, tuple):
-                    result = bw_pre_hook(*outputs)
-                else:
+                if is_backward_hook_output_a_view():
                     result = bw_pre_hook(outputs)
-                if isinstance(outputs, tuple):
-                    if len(outputs) == 1:
-                        result = (result,)
-                    if len(result) != len(outputs):
-                        raise TypeError(
-                            f"The backward pre hook return value size is {len(result)} "
-                            f"not equal to output size {len(outputs)}"
-                        )
+                else:
+                    if isinstance(outputs, tuple):
+                        result = bw_pre_hook(*outputs)
+                    else:
+                        result = bw_pre_hook(outputs)
+                    if isinstance(outputs, tuple):
+                        if len(outputs) == 1:
+                            result = (result,)
+                        if len(result) != len(outputs):
+                            raise TypeError(
+                                f"The backward pre hook return value size is {len(result)} "
+                                f"not equal to output size {len(outputs)}"
+                            )
                 return result
 
         return forward_pre_hook
diff --git a/debug/accuracy_tools/msprobe/mindspore/common/utils.py b/debug/accuracy_tools/msprobe/mindspore/common/utils.py
index 2d9d0d3d9b5c932fd5f14e682f8579a69ac74810..d35e1b51945c8f5a63f99edacadc3885570790ef 100644
--- a/debug/accuracy_tools/msprobe/mindspore/common/utils.py
+++ b/debug/accuracy_tools/msprobe/mindspore/common/utils.py
@@ -42,6 +42,7 @@ else:
 mindtorch_check_result = None
 register_backward_hook_functions = {}
 kwargs_exist_in_forward_hook = None
+is_output_of_backward_hook_a_view = None
 
 
 class MsprobeStep(ms.train.Callback):
@@ -329,3 +330,43 @@ def has_kwargs_in_forward_hook():
         return kwargs_exist_in_forward_hook
 
     return kwargs_exist_in_forward_hook
+
+
+def is_backward_hook_output_a_view():
+    global is_output_of_backward_hook_a_view
+
+    if is_output_of_backward_hook_a_view is None:
+        is_output_of_backward_hook_a_view = False
+        if getattr(ms, '__version__', '2.4.0') < '2.7.0':
+            return is_output_of_backward_hook_a_view
+        try:
+            from mindspore.ops.operations import _inner_ops as inner
+            call_func = getattr(inner.CellBackwardHook, '__call__')
+            func_params = inspect.signature(call_func).parameters
+        except Exception:
+            return is_output_of_backward_hook_a_view
+        if 'args' in func_params and func_params['args'].kind == inspect.Parameter.POSITIONAL_OR_KEYWORD:
+            is_output_of_backward_hook_a_view = True
+
+    return is_output_of_backward_hook_a_view
+
+
+def wrap_backward_hook_call_func(call_func):
+    if not is_backward_hook_output_a_view():
+        return call_func
+
+    from mindspore.common.api import _pynative_executor as executor
+    from mindspore._c_expression import CreationType
+
+    def new_call(self, args):
+        outputs = call_func(self, args)
+        if isinstance(outputs, ms.Tensor):
+            executor.set_creation_type(outputs, CreationType.DEFAULT)
+        elif isinstance(outputs, tuple):
+            for item in outputs:
+                if isinstance(item, ms.Tensor):
+                    executor.set_creation_type(item, CreationType.DEFAULT)
+        return outputs
+    new_call.__name__ = '__call__'
+
+    return new_call
diff --git a/debug/accuracy_tools/msprobe/mindspore/debugger/precision_debugger.py b/debug/accuracy_tools/msprobe/mindspore/debugger/precision_debugger.py
index b2e7289a5b14a2a63eb09fff7ac8f1cf624babfc..cd9e2e05b24a8d4bdce289d206fc3d0cbf2650cd 100644
--- a/debug/accuracy_tools/msprobe/mindspore/debugger/precision_debugger.py
+++ b/debug/accuracy_tools/msprobe/mindspore/debugger/precision_debugger.py
@@ -17,6 +17,7 @@ import os
 from collections import defaultdict, namedtuple
 
 import mindspore as ms
+from mindspore.ops.operations import _inner_ops as inner
 from mindspore._c_expression import MSContext
 
 from msprobe.core.common.const import Const, MsgConst
@@ -28,7 +29,8 @@ from msprobe.mindspore.common.const import Const as MsConst
 from msprobe.mindspore.common.utils import (
     set_register_backward_hook_functions,
     check_save_param,
-    is_graph_mode_cell_dump_allowed
+    is_graph_mode_cell_dump_allowed,
+    wrap_backward_hook_call_func
 )
 from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
 from msprobe.mindspore.dump.graph_mode_cell_dump import GraphModeCellDump
@@ -81,6 +83,9 @@ class PrecisionDebugger(BasePrecisionDebugger):
         self.common_config.dump_path = dump_path if dump_path else self.common_config.dump_path
         self.config = DebuggerConfig(self.common_config, self.task_config)
 
+        setattr(inner.CellBackwardHook, '__call__',
+                wrap_backward_hook_call_func(getattr(inner.CellBackwardHook, '__call__')))
+
         if self._is_kernel_dump() and _msprobe_c:
             os.environ["MS_HOOK_ENABLE"] = "on"
             _msprobe_c._PrecisionDebugger(framework="MindSpore", config_path=config_path)
diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_register.py b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_register.py
index ddee76596452323bf564270f526de9103a6c1a8d..d82ec093725f815f368cef52d35067411a67795b 100644
--- a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_register.py
+++ b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_register.py
@@ -69,7 +69,7 @@ else:
     }
     _supported_api_list_path = (os.path.join(cur_path, '../../../pytorch/hook_module',
                                              MsConst.SUPPORTED_API_LIST_FILE),)
-    _backlist = [f'{Const.PT_API_TYPE_TENSOR}.__setitem__']
+    _backlist = []
 
     _inner_used_api = {
         Const.MS_FRAMEWORK + Const.SEP + Const.MS_API_TYPE_OPS: (
diff --git a/debug/accuracy_tools/msprobe/mindspore/mindtorch/mindtorch_adaptor.py b/debug/accuracy_tools/msprobe/mindspore/mindtorch/mindtorch_adaptor.py
index 27e42d52ba6190ec7e7531af25464e6aa3996b2b..7ca256dbcf8f9011f5ca84898f643888ea7f890e 100644
--- a/debug/accuracy_tools/msprobe/mindspore/mindtorch/mindtorch_adaptor.py
+++ b/debug/accuracy_tools/msprobe/mindspore/mindtorch/mindtorch_adaptor.py
@@ -93,6 +93,8 @@ from torch.nn.modules.module import (_global_backward_pre_hooks, _global_backward_hooks,
                                      _global_forward_hooks, _global_forward_hooks_always_called)
 from torch.utils.hooks import RemovableHandle
 
+from msprobe.mindspore.common.utils import is_backward_hook_output_a_view
+
 
 def _call_impl(self, *args, **kwargs):
     forward_call = self.forward
@@ -245,11 +247,14 @@ def _get_backward_hooks(self):
 
 
 def apply_backward_hook_on_tensors(cell_backward_hook, args):
-    is_tuple = True
-    if not isinstance(args, tuple):
-        args = (args,)
-        is_tuple = False
-    hooked_args = cell_backward_hook(*args)
-    if is_tuple and len(args) == 1:
-        hooked_args = (hooked_args, )
+    if is_backward_hook_output_a_view():
+        hooked_args = cell_backward_hook(args)
+    else:
+        is_tuple = True
+        if not isinstance(args, tuple):
+            args = (args,)
+            is_tuple = False
+        hooked_args = cell_backward_hook(*args)
+        if is_tuple and len(args) == 1:
+            hooked_args = (hooked_args, )
     return hooked_args