diff --git a/component/taskd/tests/ut/api/test_taskd_wokrer_api.py b/component/taskd/tests/ut/api/test_taskd_wokrer_api.py index 2a5e9de1e2ae5f7a171ecbc61b1b9c84f29b183b..49bbac856cb778917c20bebe1926b723fa1fd69a 100644 --- a/component/taskd/tests/ut/api/test_taskd_wokrer_api.py +++ b/component/taskd/tests/ut/api/test_taskd_wokrer_api.py @@ -15,13 +15,16 @@ # limitations under the License. # ============================================================================== import unittest -from unittest.mock import patch +from unittest.mock import patch, MagicMock -from taskd.api.taskd_worker_api import init_taskd_worker +from taskd.api.taskd_worker_api import init_taskd_worker, start_taskd_worker + +taskd_worker = None +run_log = MagicMock() class WorkerTestCase(unittest.TestCase): - def test_init_taskd_worker_success(self, mock_worker): + def test_init_taskd_worker_success(self): rank_id = 'not_an_int' upper_limit = 5000 result = init_taskd_worker(rank_id, upper_limit) @@ -39,6 +42,27 @@ class WorkerTestCase(unittest.TestCase): result = init_taskd_worker(rank_id, upper_limit) self.assertFalse(result) + def test_worker_not_initialized(self): + result = start_taskd_worker() + self.assertEqual(result, False) + + @patch('taskd.api.taskd_worker_api.taskd_worker') + def test_worker_start_success(self, mock_worker): + global taskd_worker + taskd_worker = mock_worker + mock_worker.start.return_value = True + result = start_taskd_worker() + self.assertEqual(result, True) + mock_worker.start.assert_called_once() + + @patch('taskd.api.taskd_worker_api.taskd_worker') + def test_worker_start_failure(self, mock_worker): + global taskd_worker + taskd_worker = mock_worker + mock_worker.start.side_effect = Exception("Test exception") + result = start_taskd_worker() + self.assertEqual(result, False) + if __name__ == '__main__': unittest.main()