From 8e94d1660054096f65d9d0a9bb75076768b7ab8b Mon Sep 17 00:00:00 2001 From: xiaxia3 Date: Mon, 21 Mar 2022 18:58:25 +0800 Subject: [PATCH] =?UTF-8?q?=E8=A1=A5=E5=85=85Module=E7=B1=BB=E5=9E=8B?= =?UTF-8?q?=E4=BF=9D=E5=AD=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- test/test_api/test_torch/test_serialization.py | 9 +++++---- torch_npu/utils/serialization.py | 11 ++--------- 2 files changed, 7 insertions(+), 13 deletions(-) diff --git a/test/test_api/test_torch/test_serialization.py b/test/test_api/test_torch/test_serialization.py index 6d4b315cf71..6d00c022e4b 100644 --- a/test/test_api/test_torch/test_serialization.py +++ b/test/test_api/test_torch/test_serialization.py @@ -56,13 +56,14 @@ class TestSerialization(TestCase): def test_save_tuple(self): x = torch.randn(5).npu() + model = NpuMNIST().npu() number = 3 with tempfile.TemporaryDirectory() as tmpdir: path = os.path.join(tmpdir, 'data.pt') - torch.save((x, number), path) - x_loaded, number_loaded = torch.load(path) - x_loaded = x_loaded.npu() - self.assertRtolEqual(x.cpu(), x_loaded.cpu()) + torch.save((x, model, number), path) + x_loaded, model_loaded, number_loaded = torch.load(path) + self.assertRtolEqual(x.cpu(), x_loaded) + self.assertExpectedInline(str(model), str(model_loaded)) self.assertTrue(number, number_loaded) def test_save_value(self): diff --git a/torch_npu/utils/serialization.py b/torch_npu/utils/serialization.py index f7626f7332d..c9a2c5ce77b 100644 --- a/torch_npu/utils/serialization.py +++ b/torch_npu/utils/serialization.py @@ -31,31 +31,24 @@ def to_cpu(data): list_value = list(value) to_cpu(list_value) data[i] = tuple(list_value) - elif isinstance(value, string_classes): continue - elif isinstance(value, (container_abcs.Sequence, container_abcs.Mapping)): to_cpu(value) - - elif isinstance(value, torch.Tensor): + elif isinstance(value, torch.Tensor) or isinstance(value, nn.Module): data[i] = value.cpu() - if isinstance(data, container_abcs.Mapping): for key, value in data.items(): if isinstance(value, tuple): list_value = list(value) to_cpu(list_value) data[key] = tuple(list_value) - elif isinstance(value, (container_abcs.Sequence, container_abcs.Mapping)): to_cpu(value) - - elif isinstance(value, torch.Tensor): + elif isinstance(value, torch.Tensor) or isinstance(value, nn.Module): data[key] = value.cpu() - def save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_new_zipfile_serialization=False): """Saves the input data into a file. -- Gitee