From ee8f6a211a7ed3323e617da4c36e09908a3748e6 Mon Sep 17 00:00:00 2001 From: xiaxia3 Date: Mon, 28 Feb 2022 11:10:52 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9save=E7=9A=84else=E5=88=86?= =?UTF-8?q?=E6=94=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- test/test_api/test_serialization.py | 11 +++++++---- torch_npu/utils/serialization.py | 3 ++- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/test/test_api/test_serialization.py b/test/test_api/test_serialization.py index 73a2f64d716..0289bfa3f2d 100644 --- a/test/test_api/test_serialization.py +++ b/test/test_api/test_serialization.py @@ -64,10 +64,13 @@ class TestSerialization(TestCase): self.assertRtolEqual(x.cpu(), x_loaded.cpu()) self.assertTrue(number, number_loaded) - def test_save_error(self, device="npu"): - a = 44 - with self.assertRaisesRegex(RuntimeError, "torch.save received invalid input."): - out = torch.save(a, 'a.pth') + def test_save_value(self, device="npu"): + x = 44 + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, 'data.pt') + torch.save(x, path) + x_loaded = torch.load(path) + self.assertTrue(x, x_loaded) def test_serialization_model(self, device="npu"): with tempfile.TemporaryDirectory() as tmpdir: diff --git a/torch_npu/utils/serialization.py b/torch_npu/utils/serialization.py index ea859e5399a..9c2ba25d04c 100644 --- a/torch_npu/utils/serialization.py +++ b/torch_npu/utils/serialization.py @@ -87,8 +87,9 @@ def save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_ne elif isinstance(obj, nn.Module): obj = obj.cpu() se.save(obj, f, pickle_module, pickle_protocol, _use_new_zipfile_serialization) + else: - raise RuntimeError('torch.save received invalid input.') + se.save(obj, f, pickle_module, pickle_protocol, _use_new_zipfile_serialization) def load(f, map_location=None, pickle_module=pickle, **pickle_load_args): """Loads data previously saved with the `save()` API. -- Gitee