From 5ed2b26bbd3a9080b3f3158c8680c73220210baa Mon Sep 17 00:00:00 2001 From: l30036321 Date: Thu, 27 Mar 2025 20:08:07 +0800 Subject: [PATCH] modify view type replace clone --- .../dump/module_dump/module_processer.py | 27 ++++++++++--------- 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/module_processer.py b/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/module_processer.py index d9b67c9317..4f148008eb 100644 --- a/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/module_processer.py +++ b/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/module_processer.py @@ -47,30 +47,33 @@ class ModuleProcesser: def __init__(self, scope): self.scope = scope if isinstance(scope, (ModuleRangeScope, MixRangeScope)) else None - BackwardHook.setup_input_hook = ModuleProcesser.clone_return_value(BackwardHook.setup_input_hook) - BackwardHook.setup_output_hook = ModuleProcesser.clone_return_value(BackwardHook.setup_output_hook) + BackwardHook.setup_input_hook = ModuleProcesser.modify_view_type_return_value(BackwardHook.setup_input_hook) + BackwardHook.setup_output_hook = ModuleProcesser.modify_view_type_return_value(BackwardHook.setup_output_hook) replace_checkpoint() @staticmethod - def clone_return_value(func): + def modify_view_type_return_value(func): @wraps(func) - def clone_return_value_func(*args, **kwargs): + def modify_view_type_return_value_func(*args, **kwargs): result = func(*args, **kwargs) - return ModuleProcesser.clone_if_tensor(result) + return ModuleProcesser.modify_view_type(result) - return clone_return_value_func + return modify_view_type_return_value_func @staticmethod - @recursion_depth_decorator("ModuleDump: ModuleProcesser.clone_if_tensor", max_depth=Const.DUMP_MAX_DEPTH) - def clone_if_tensor(result): + @recursion_depth_decorator("ModuleDump: ModuleProcesser.modify_view_type", max_depth=Const.DUMP_MAX_DEPTH) + def modify_view_type(result): if isinstance(result, torch.Tensor) and not is_float8_tensor(result): - return result.clone() + if hasattr(result, "_base") and result._base is not None: + if torch._C._autograd._get_creation_meta(result) != torch._C._autograd.CreationMeta(0): + torch._C._autograd._set_creation_meta(result, torch._C._autograd.CreationMeta(0)) + return result elif type(result) is tuple: - return tuple(ModuleProcesser.clone_if_tensor(x) for x in result) + return tuple(ModuleProcesser.modify_view_type(x) for x in result) elif type(result) is list: - return list(ModuleProcesser.clone_if_tensor(x) for x in result) + return list(ModuleProcesser.modify_view_type(x) for x in result) elif type(result) is dict: - return {k: ModuleProcesser.clone_if_tensor(v) for k, v in result.items()} + return {k: ModuleProcesser.modify_view_type(v) for k, v in result.items()} else: return result -- Gitee