diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/api_accuracy_checker/test_multi_api_accuracy_checker.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/api_accuracy_checker/test_multi_api_accuracy_checker.py index d74a29ec73e7d97646cc3a49983b5a34ba2731d0..2337f1c1954ef0eb547e742e9801dcfcf4b03e3a 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/api_accuracy_checker/test_multi_api_accuracy_checker.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/api_accuracy_checker/test_multi_api_accuracy_checker.py @@ -188,48 +188,6 @@ class TestMultiApiAccuracyChecker(unittest.TestCase): self.assertEqual(result, backward_output_list) - @patch('msprobe.mindspore.api_accuracy_checker.multi_api_accuracy_checker.tqdm') - @patch('multiprocessing.Process') - @patch('multiprocessing.Queue') - @patch('msprobe.mindspore.api_accuracy_checker.multi_api_accuracy_checker.logger') - def test_run_and_compare(self, mock_logger, mock_queue_class, mock_process_class, mock_tqdm): - # 模拟进程和队列 - # 创建一个假的进度队列 - mock_queue = MagicMock() - # 设置进度队列的 get 方法,每次返回 1,总共返回 len(self.checker.api_infos) 次 - mock_queue.get.side_effect = [1] * len(self.checker.api_infos) - mock_queue_class.return_value = mock_queue - - # 创建模拟的进程列表 - mock_processes = [] - for _ in self.args.device_id: - mock_process = MagicMock() - mock_process.is_alive.return_value = False # 模拟进程已完成 - mock_process.exitcode = 0 # 模拟进程正常退出 - mock_process.pid = 12345 # 模拟进程ID - mock_processes.append(mock_process) - - # 设置 Process 的 side_effect,每次调用返回不同的进程对象 - mock_process_class.side_effect = mock_processes - - # 模拟 tqdm - mock_pbar = MagicMock() - mock_tqdm.return_value.__enter__.return_value = mock_pbar - - # 运行方法 - self.checker.run_and_compare() - - # 验证进程被正确创建 - self.assertEqual(mock_process_class.call_count, len(self.args.device_id)) - - # 验证进度条被正确初始化 - mock_tqdm.assert_called_once_with(total=len(self.checker.api_infos), desc="Total Progress", ncols=100) - - # 验证进度队列的 get 方法被正确调用 - self.assertEqual(mock_queue.get.call_count, len(self.checker.api_infos)) - - # 验证进度条的 update 方法被正确调用 - self.assertEqual(mock_pbar.update.call_count, len(self.checker.api_infos)) @patch('msprobe.mindspore.api_accuracy_checker.multi_api_accuracy_checker.context') def test_process_on_device_api_not_unique(self, mock_context):