From 4647e77c367b4359cb1b337003feff40717ec70a Mon Sep 17 00:00:00 2001 From: pxp1 <958876660@qq.com> Date: Thu, 7 Aug 2025 15:10:46 +0800 Subject: [PATCH] =?UTF-8?q?=E5=88=A0=E9=99=A4=E6=97=A0=E7=94=A8=E4=BB=A3?= =?UTF-8?q?=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../msprobe/pytorch/common/utils.py | 20 ---------- .../test/pytorch_ut/common/test_pt_utils.py | 37 ------------------- 2 files changed, 57 deletions(-) diff --git a/debug/accuracy_tools/msprobe/pytorch/common/utils.py b/debug/accuracy_tools/msprobe/pytorch/common/utils.py index 2aeb585fc6..55a3687299 100644 --- a/debug/accuracy_tools/msprobe/pytorch/common/utils.py +++ b/debug/accuracy_tools/msprobe/pytorch/common/utils.py @@ -388,26 +388,6 @@ def load_pkl(pt_path): return pt -def save_api_data(api_data): - """Save data to io stream""" - try: - io_buff = io.BytesIO() - torch.save(api_data, io_buff) - except Exception as e: - raise RuntimeError("save api_data to io_buff failed") from e - return io_buff - - -def load_api_data(api_data_bytes): - """Load data from bytes stream""" - try: - buffer = io.BytesIO(api_data_bytes) - buffer = torch.load(buffer, map_location="cpu", weights_only=False) - except Exception as e: - raise RuntimeError("load api_data from bytes failed") from e - return buffer - - def is_recomputation(): """Check if the current operation is in the re-computation phase. diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/common/test_pt_utils.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/common/test_pt_utils.py index c28557f8fc..709d252f39 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/common/test_pt_utils.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/common/test_pt_utils.py @@ -31,8 +31,6 @@ from msprobe.pytorch.common.utils import ( print_rank_0, load_pt, save_pt, - save_api_data, - load_api_data, save_pkl, load_pkl ) @@ -234,41 +232,6 @@ class TestSavePT(unittest.TestCase): self.assertIn("save pt file temp_tensor.pt failed", str(context.exception)) -class TestSaveApiData(unittest.TestCase): - - def test_save_api_data_success(self): - api_data = {"key": "value"} - io_buff = save_api_data(api_data) - self.assertIsInstance(io_buff, io.BytesIO) - io_buff.seek(0) - loaded_data = torch.load(io_buff) - self.assertEqual(loaded_data, api_data) - - def test_save_api_data_failure(self): - api_data = MagicMock() - with patch('torch.save', side_effect=Exception("save error")): - with self.assertRaises(RuntimeError) as context: - save_api_data(api_data) - self.assertIn("save api_data to io_buff failed", str(context.exception)) - - -class TestLoadApiData(unittest.TestCase): - - def test_load_api_data_success(self): - mock_tensor = torch.tensor([1, 2, 3]) - buffer = io.BytesIO() - torch.save(mock_tensor, buffer) - buffer.seek(0) - result = load_api_data(buffer.read()) - self.assertTrue(torch.equal(result, mock_tensor)) - - def test_load_api_data_failure(self): - invalid_bytes = b'not a valid tensor' - with self.assertRaises(RuntimeError) as context: - load_api_data(invalid_bytes) - self.assertIn("load api_data from bytes failed", str(context.exception)) - - class TestSavePkl(unittest.TestCase): def setUp(self): -- Gitee