diff --git a/test/test_api/test_serialization.py b/test/test_api/test_serialization.py index 73a2f64d716ee5683956f5fb13d1462c32e4f193..0289bfa3f2dd9cf00aefe845b17d5a0cd68ab883 100644 --- a/test/test_api/test_serialization.py +++ b/test/test_api/test_serialization.py @@ -64,10 +64,13 @@ class TestSerialization(TestCase): self.assertRtolEqual(x.cpu(), x_loaded.cpu()) self.assertTrue(number, number_loaded) - def test_save_error(self, device="npu"): - a = 44 - with self.assertRaisesRegex(RuntimeError, "torch.save received invalid input."): - out = torch.save(a, 'a.pth') + def test_save_value(self, device="npu"): + x = 44 + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, 'data.pt') + torch.save(x, path) + x_loaded = torch.load(path) + self.assertTrue(x, x_loaded) def test_serialization_model(self, device="npu"): with tempfile.TemporaryDirectory() as tmpdir: diff --git a/torch_npu/utils/serialization.py b/torch_npu/utils/serialization.py index ea859e5399a2d227800b26bccd45e0b856d69a5b..9c2ba25d04c7279e1c04babdb4e47d3b2179ac59 100644 --- a/torch_npu/utils/serialization.py +++ b/torch_npu/utils/serialization.py @@ -87,8 +87,9 @@ def save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_ne elif isinstance(obj, nn.Module): obj = obj.cpu() se.save(obj, f, pickle_module, pickle_protocol, _use_new_zipfile_serialization) + else: - raise RuntimeError('torch.save received invalid input.') + se.save(obj, f, pickle_module, pickle_protocol, _use_new_zipfile_serialization) def load(f, map_location=None, pickle_module=pickle, **pickle_load_args): """Loads data previously saved with the `save()` API.