From 8420ab4c823a48a6de5b76f2cbab5e022e3b9fc8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B0=A4=E5=AE=89=E5=8D=87?= Date: Fri, 1 Jul 2022 09:29:52 +0800 Subject: [PATCH] Revert to deepcopy. --- torch_npu/utils/serialization.py | 31 +++---------------------------- 1 file changed, 3 insertions(+), 28 deletions(-) diff --git a/torch_npu/utils/serialization.py b/torch_npu/utils/serialization.py index ab8e3fd6a24..be632a25d3f 100644 --- a/torch_npu/utils/serialization.py +++ b/torch_npu/utils/serialization.py @@ -33,9 +33,7 @@ def to_cpu(data): return data.cpu() if isinstance(data, nn.Module): - if torch_npu._C.is_npu(next(data.parameters())): - setattr(data, "mark_npu", True) - return data.cpu() + return copy.deepcopy(data).cpu() if isinstance(data, argparse.Namespace): dict_obj = vars(data) @@ -55,9 +53,7 @@ def to_cpu(data): elif isinstance(value, torch.Tensor): copy_data[i] = value.cpu() elif isinstance(value, nn.Module): - if torch_npu._C.is_npu(next(value.parameters())): - setattr(value, "mark_npu", True) - copy_data[i] = value.cpu() + copy_data[i] = copy.deepcopy(value).cpu() else: copy_data[i] = value return type(data)(copy_data) @@ -74,9 +70,7 @@ def to_cpu(data): elif isinstance(value, torch.Tensor): copy_data[key] = value.cpu() elif isinstance(value, nn.Module): - if torch_npu._C.is_npu(next(value.parameters())): - setattr(value, "mark_npu", True) - copy_data[key] = value.cpu() + copy_data[key] = copy.deepcopy(value).cpu() else: copy_data[key] = value return copy_data @@ -84,24 +78,6 @@ def to_cpu(data): return data -def module_to_npu(data): - if isinstance(data, nn.Module) and hasattr(data, "mark_npu"): - delattr(data, "mark_npu") - data.npu() - - if isinstance(data, container_abcs.Sequence): - for value in data: - if isinstance(value, nn.Module) and hasattr(value, "mark_npu"): - delattr(value, "mark_npu") - value.npu() - - if isinstance(data, container_abcs.Mapping): - for _, value in data.items(): - if isinstance(value, nn.Module) and hasattr(value, "mark_npu"): - delattr(value, "mark_npu") - value.npu() - - def save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_new_zipfile_serialization=False): """Saves the input data into a file. @@ -118,7 +94,6 @@ def save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_ne the same host will override each other. """ se.save(to_cpu(obj), f, pickle_module, pickle_protocol, _use_new_zipfile_serialization) - module_to_npu(obj) def load(f, map_location=None, pickle_module=pickle, **pickle_load_args): -- Gitee