diff --git a/debug/accuracy_tools/msprobe/pytorch/common/utils.py b/debug/accuracy_tools/msprobe/pytorch/common/utils.py index 2aeb585fc6c28fe2643c8f2d491346fc23a1c419..55a3687299f643f7cd30ace3a78368a6f209c007 100644 --- a/debug/accuracy_tools/msprobe/pytorch/common/utils.py +++ b/debug/accuracy_tools/msprobe/pytorch/common/utils.py @@ -388,26 +388,6 @@ def load_pkl(pt_path): return pt -def save_api_data(api_data): - """Save data to io stream""" - try: - io_buff = io.BytesIO() - torch.save(api_data, io_buff) - except Exception as e: - raise RuntimeError("save api_data to io_buff failed") from e - return io_buff - - -def load_api_data(api_data_bytes): - """Load data from bytes stream""" - try: - buffer = io.BytesIO(api_data_bytes) - buffer = torch.load(buffer, map_location="cpu", weights_only=False) - except Exception as e: - raise RuntimeError("load api_data from bytes failed") from e - return buffer - - def is_recomputation(): """Check if the current operation is in the re-computation phase. diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/common/test_pt_utils.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/common/test_pt_utils.py index c28557f8fca6219c7bafcf5d676f89172a1efb96..709d252f39640578be2dab9edf390862efed8e9a 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/common/test_pt_utils.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/common/test_pt_utils.py @@ -31,8 +31,6 @@ from msprobe.pytorch.common.utils import ( print_rank_0, load_pt, save_pt, - save_api_data, - load_api_data, save_pkl, load_pkl ) @@ -234,41 +232,6 @@ class TestSavePT(unittest.TestCase): self.assertIn("save pt file temp_tensor.pt failed", str(context.exception)) -class TestSaveApiData(unittest.TestCase): - - def test_save_api_data_success(self): - api_data = {"key": "value"} - io_buff = save_api_data(api_data) - self.assertIsInstance(io_buff, io.BytesIO) - io_buff.seek(0) - loaded_data = torch.load(io_buff) - self.assertEqual(loaded_data, api_data) - - def test_save_api_data_failure(self): - api_data = MagicMock() - with patch('torch.save', side_effect=Exception("save error")): - with self.assertRaises(RuntimeError) as context: - save_api_data(api_data) - self.assertIn("save api_data to io_buff failed", str(context.exception)) - - -class TestLoadApiData(unittest.TestCase): - - def test_load_api_data_success(self): - mock_tensor = torch.tensor([1, 2, 3]) - buffer = io.BytesIO() - torch.save(mock_tensor, buffer) - buffer.seek(0) - result = load_api_data(buffer.read()) - self.assertTrue(torch.equal(result, mock_tensor)) - - def test_load_api_data_failure(self): - invalid_bytes = b'not a valid tensor' - with self.assertRaises(RuntimeError) as context: - load_api_data(invalid_bytes) - self.assertIn("load api_data from bytes failed", str(context.exception)) - - class TestSavePkl(unittest.TestCase): def setUp(self):