class ModuleProcesser:
    # NOTE(review): reconstructed from a whitespace-collapsed diff hunk. Only the
    # post-patch members visible in that hunk are reproduced here; any other
    # members of the class are outside this view.

    def __init__(self, scope):
        """Store the dump scope and install the return-value wrappers.

        Args:
            scope: Kept on ``self.scope`` only when it is a ``ModuleRangeScope``
                or ``MixRangeScope`` instance; any other value is stored as
                ``None``.

        Side effects:
            Rebinds ``BackwardHook.setup_input_hook`` and
            ``BackwardHook.setup_output_hook`` to wrapped versions and calls
            ``replace_checkpoint()``.
        """
        self.scope = scope if isinstance(scope, (ModuleRangeScope, MixRangeScope)) else None
        # The wrapper factory is idempotent, so building several ModuleProcesser
        # instances does not stack one wrapper layer per instance on these
        # class-level hooks.
        BackwardHook.setup_input_hook = ModuleProcesser.modify_view_type_return_value(BackwardHook.setup_input_hook)
        BackwardHook.setup_output_hook = ModuleProcesser.modify_view_type_return_value(BackwardHook.setup_output_hook)
        replace_checkpoint()

    @staticmethod
    def modify_view_type_return_value(func):
        """Wrap ``func`` so its return value is passed through ``modify_view_type``.

        Args:
            func: Callable whose result should be post-processed.

        Returns:
            A wrapper with ``func``'s metadata (via ``functools.wraps``). If
            ``func`` was already produced by this factory it is returned
            unchanged, which keeps repeated patching from nesting wrappers.
        """
        marker = "_msprobe_view_type_wrapped"
        if getattr(func, marker, False):
            return func

        @wraps(func)
        def modify_view_type_return_value_func(*args, **kwargs):
            result = func(*args, **kwargs)
            return ModuleProcesser.modify_view_type(result)

        setattr(modify_view_type_return_value_func, marker, True)
        return modify_view_type_return_value_func

    @staticmethod
    @recursion_depth_decorator("ModuleDump: ModuleProcesser.modify_view_type", max_depth=Const.DUMP_MAX_DEPTH)
    def modify_view_type(result):
        """Reset the autograd creation meta of tensors found in ``result``.

        Tensors are modified in place and returned as the same object (no copy).
        Containers whose exact type is ``tuple``, ``list`` or ``dict`` are
        rebuilt with their elements/values processed recursively; anything else
        is returned untouched.

        Args:
            result: A tensor, a (possibly nested) tuple/list/dict, or any other
                object.

        Returns:
            ``result`` itself for tensors and unrecognised objects, or a new
            container of the same built-in type with processed contents.
        """
        if isinstance(result, torch.Tensor) and not is_float8_tensor(result):
            # A non-None ``_base`` presumably identifies a view tensor -- confirm.
            if getattr(result, "_base", None) is not None:
                autograd_c = torch._C._autograd
                # NOTE(review): CreationMeta(0) looks like the DEFAULT creation
                # meta; resetting it presumably avoids autograd's restrictions
                # on views created in special contexts -- confirm against the
                # targeted torch versions.
                default_meta = autograd_c.CreationMeta(0)
                if autograd_c._get_creation_meta(result) != default_meta:
                    autograd_c._set_creation_meta(result, default_meta)
            return result

        # Exact type checks on purpose: subclasses (e.g. namedtuples) fall
        # through and are returned as-is instead of being rebuilt.
        if type(result) is tuple:
            return tuple(ModuleProcesser.modify_view_type(x) for x in result)
        if type(result) is list:
            return [ModuleProcesser.modify_view_type(x) for x in result]
        if type(result) is dict:
            return {k: ModuleProcesser.modify_view_type(v) for k, v in result.items()}
        return result