diff --git a/test/test_api/test_torch/test_serialization.py b/test/test_api/test_torch/test_serialization.py index 6d4b315cf71887add76679ab7842302a5f33f7d1..6d00c022e4b9f51555be0d249985d5d8ea37ce87 100644 --- a/test/test_api/test_torch/test_serialization.py +++ b/test/test_api/test_torch/test_serialization.py @@ -56,13 +56,14 @@ class TestSerialization(TestCase): def test_save_tuple(self): x = torch.randn(5).npu() + model = NpuMNIST().npu() number = 3 with tempfile.TemporaryDirectory() as tmpdir: path = os.path.join(tmpdir, 'data.pt') - torch.save((x, number), path) - x_loaded, number_loaded = torch.load(path) - x_loaded = x_loaded.npu() - self.assertRtolEqual(x.cpu(), x_loaded.cpu()) + torch.save((x, model, number), path) + x_loaded, model_loaded, number_loaded = torch.load(path) + self.assertRtolEqual(x.cpu(), x_loaded) + self.assertExpectedInline(str(model), str(model_loaded)) self.assertTrue(number, number_loaded) def test_save_value(self): diff --git a/torch_npu/utils/serialization.py b/torch_npu/utils/serialization.py index f7626f7332d354197bf46836ee178590674c61dc..c9a2c5ce77bad65bf73a90dd27c921ec55e9b45e 100644 --- a/torch_npu/utils/serialization.py +++ b/torch_npu/utils/serialization.py @@ -31,31 +31,24 @@ def to_cpu(data): list_value = list(value) to_cpu(list_value) data[i] = tuple(list_value) - elif isinstance(value, string_classes): continue - elif isinstance(value, (container_abcs.Sequence, container_abcs.Mapping)): to_cpu(value) - - elif isinstance(value, torch.Tensor): + elif isinstance(value, torch.Tensor) or isinstance(value, nn.Module): data[i] = value.cpu() - if isinstance(data, container_abcs.Mapping): for key, value in data.items(): if isinstance(value, tuple): list_value = list(value) to_cpu(list_value) data[key] = tuple(list_value) - elif isinstance(value, (container_abcs.Sequence, container_abcs.Mapping)): to_cpu(value) - - elif isinstance(value, torch.Tensor): + elif isinstance(value, torch.Tensor) or isinstance(value, nn.Module): data[key] = value.cpu() - def save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_new_zipfile_serialization=False): """Saves the input data into a file.