From 5c7ba275e982f4d373fc804576a7b4014aefd509 Mon Sep 17 00:00:00 2001 From: kai-ma Date: Thu, 24 Jul 2025 17:01:23 +0800 Subject: [PATCH] add ut for base,common --- msprobe/test/ST/run_st.py | 0 .../component_ut/test_manager_component.py | 262 ++++++++++++++++++ .../service_ut/test_manager_service.py | 79 ++++++ msprobe/test/UT/base_ut/test_cmd.py | 86 ++++++ msprobe/test/UT/base_ut/test_config.py | 135 +++++++++ msprobe/test/UT/common_ut/test_ascend.py | 0 msprobe/test/UT/common_ut/test_cli.py | 47 ++++ msprobe/test/UT/common_ut/test_validation.py | 241 ++++++++++++++++ 8 files changed, 850 insertions(+) create mode 100644 msprobe/test/ST/run_st.py create mode 100644 msprobe/test/UT/base_ut/component_ut/test_manager_component.py create mode 100644 msprobe/test/UT/base_ut/service_ut/test_manager_service.py create mode 100644 msprobe/test/UT/base_ut/test_cmd.py create mode 100644 msprobe/test/UT/base_ut/test_config.py create mode 100644 msprobe/test/UT/common_ut/test_ascend.py create mode 100644 msprobe/test/UT/common_ut/test_cli.py create mode 100644 msprobe/test/UT/common_ut/test_validation.py diff --git a/msprobe/test/ST/run_st.py b/msprobe/test/ST/run_st.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/msprobe/test/UT/base_ut/component_ut/test_manager_component.py b/msprobe/test/UT/base_ut/component_ut/test_manager_component.py new file mode 100644 index 0000000000..c0c9eabfe1 --- /dev/null +++ b/msprobe/test/UT/base_ut/component_ut/test_manager_component.py @@ -0,0 +1,262 @@ +import unittest +from unittest.mock import MagicMock + +from msprobe.base.component.manager import BaseComponent, Component, ConsumerComp, ProducerComp, Scheduler +from msprobe.utils.constants import MsgConst +from msprobe.utils.exceptions import MsprobeException + + +class TestBaseComponent: + def test_initialization(self): + component = BaseComponent(priority=200) + assert component.priority == 200 + assert component.is_activated is False + + def test_do_activate(self): + component = BaseComponent() + component.activate = MagicMock() + assert component.is_activated is False + + component.do_activate() + assert component.is_activated is True + component.activate.assert_called_once() + + component.do_activate() + component.activate.assert_called_once() + + def test_do_deactivate(self): + component = BaseComponent() + component.deactivate = MagicMock() + component.activated = True + assert component.is_activated is True + + component.do_deactivate() + assert component.is_activated is False + component.deactivate.assert_called_once() + + component.do_deactivate() + component.deactivate.assert_called_once() + + def test_activate_does_not_change_state_directly(self): + component = BaseComponent() + component.activate() + assert component.is_activated is False + + def test_deactivate_does_not_change_state_directly(self): + component = BaseComponent() + component.activated = True + component.deactivate() + assert component.is_activated is True + + +class ConcreteProducerComp(ProducerComp): + def __init__(self, priority): + super(ConcreteProducerComp, self).__init__(priority) + self._data_generated = False + + def load_data(self): + if not self._data_generated: + self._data_generated = True + return "generated_data" + return None + + +class ConcreteConsumerComp(ConsumerComp): + def __init__(self, priority): + super(ConcreteConsumerComp, self).__init__(priority) + + def consume(self, packages): + print("Consuming data:", packages) + + +class HybridComp(ProducerComp, ConsumerComp): + def __init__(self, priority): + super(HybridComp, self).__init__(priority) + self._data_generated = False + + def load_data(self): + if not self._data_generated: + self._data_generated = True + return "generated_data" + return None + + def consume(self, packages): + print("Consuming:", packages) + + +class TestProducerComp(unittest.TestCase): + def setUp(self): + self.producer = ConcreteProducerComp(priority=100) + self.scheduler_mock = MagicMock() + self.producer.scheduler = self.scheduler_mock + + self.producer.activate = MagicMock() + self.producer.deactivate = MagicMock() + self.producer.publish = MagicMock() + + def test_do_activate(self): + self.assertFalse(self.producer.is_activated) + self.producer.do_activate() + self.assertTrue(self.producer.is_activated) + self.producer.activate.assert_called_once() + + def test_do_deactivate(self): + self.producer.activated = True + self.producer.do_deactivate() + self.assertFalse(self.producer.is_activated) + self.producer.deactivate.assert_called_once() + + def test_retrieve(self): + self.producer.publish("some_data", msg_id=1) + self.assertEqual(len(self.producer.output_buffer), 0) + + def test_do_load_data_when_output_buffer_is_none(self): + self.producer.load_data = MagicMock(return_value="generated_data") + self.producer.do_load_data() + self.producer.load_data.assert_called_once() + self.producer.publish.assert_called_once_with("generated_data") + + def test_do_load_data_when_output_buffer_is_not_none(self): + self.producer.output_buffer = ["some_data"] + self.producer.do_load_data() + self.producer.publish.assert_not_called() + + +class TestConsumerComp(unittest.TestCase): + def setUp(self): + self.producer = ConcreteProducerComp(priority=1) + self.consumer = ConcreteConsumerComp(priority=2) + self.comp_a = HybridComp(priority=100) + self.comp_b = HybridComp(priority=200) + self.comp_c = HybridComp(priority=300) + self.consumer.consume = MagicMock() + + def test_do_consume_with_empty_dependencies(self): + self.consumer.dependencies = {MagicMock(): None} + self.consumer.do_consume() + self.consumer.consume.assert_not_called() + + def test_do_consume_with_filled_dependencies(self): + mock_producer = MagicMock() + package_data = [mock_producer, "mock_data", 1] + + self.consumer.dependencies = {mock_producer: package_data} + self.consumer.do_consume() + + self.consumer.consume.assert_called_once_with([package_data]) + self.assertEqual(self.consumer.dependencies[mock_producer], None) + + def test_do_consume_partial_dependencies(self): + mock_producer1 = MagicMock() + mock_producer2 = MagicMock() + package_data = [mock_producer1, "mock_data", 1] + + self.consumer.dependencies = {mock_producer1: package_data, mock_producer2: None} + self.consumer.do_consume() + self.consumer.consume.assert_not_called() + + def test_subscribe_valid(self): + self.consumer.subscribe(self.producer) + self.assertIn(self.consumer, self.producer.get_subscribers()) + + def test_subscribe_invalid_type(self): + with self.assertRaises(MsprobeException): + self.consumer.subscribe(self.consumer) + + def test_no_cycle(self): + self.comp_a.subscribe(self.comp_b) + self.comp_b.subscribe(self.comp_c) + try: + self.comp_c.subscribe(self.comp_a) + self.assertTrue(True) + except MsprobeException as e: + self.fail(f"Unexpected cycle detection exception: {e}") + + def test_already_subscribed(self): + self.comp_a.subscribe(self.comp_b) + self.comp_b.subscribe(self.comp_c) + self.comp_c.subscribe(self.comp_a) + self.assertEqual(len(self.comp_c.dependencies), 1) + + def test_multiple_cycles(self): + self.comp_a.subscribe(self.comp_b) + self.comp_b.subscribe(self.comp_c) + self.comp_c.subscribe(self.comp_a) + with self.assertRaises(MsprobeException) as context: + self.comp_a.subscribe(self.comp_c) + self.assertIn(MsgConst.RISK_ALERT, str(context.exception)) + + def test_on_receive(self): + package = [self.producer, "test_data", 0] + self.consumer.on_receive(package) + self.assertEqual(self.consumer.dependencies[self.producer], package) + + def test_get_empty_dependencies(self): + self.consumer.subscribe(self.producer) + self.assertIn(self.producer, self.consumer.get_empty_dependencies()) + + def test_do_consume(self): + self.consumer.subscribe(self.producer) + package = [self.producer, "test_data", 0] + self.consumer.on_receive(package) + + +class TestRegisterDecorator(unittest.TestCase): + def setUp(self): + Component._component_type_map = {} + + def test_register_decorator(self): + @Component.register("ComponentB") + class ComponentB: + pass + + self.assertIn("ComponentB", Component._component_type_map) + self.assertEqual(Component._component_type_map["ComponentB"], ComponentB) + + def test_get_registered_component(self): + @Component.register("ComponentC") + class ComponentC: + pass + + component = Component.get("ComponentC") + self.assertEqual(component, ComponentC) + + +class TestScheduler(unittest.TestCase): + def setUp(self): + self.scheduler = Scheduler() + self.producer = MagicMock(ProducerComp) + self.consumer = MagicMock(ConsumerComp) + self.producer.is_ready = True + self.consumer.is_activated = False + self.consumer.get_empty_dependencies.return_value = [] + self.consumer.do_consume = MagicMock() + + def test_add_component(self): + self.scheduler.add([self.producer]) + self.assertIn(self.producer, self.scheduler.comp_ref) + self.assertEqual(self.scheduler.comp_ref[self.producer], 1) + + def test_remove_component(self): + self.scheduler.add([self.producer]) + self.scheduler.remove([self.producer]) + self.assertNotIn(self.producer, self.scheduler.comp_ref) + + def test_schedule_consumer_with_unready_dependencies(self): + dependency_mock = MagicMock() + dependency_mock.is_ready = False + dependency_mock.do_load_data = MagicMock() + + self.consumer.get_empty_dependencies.return_value = [dependency_mock] + self.scheduler._schedule_consumer(self.consumer) + self.consumer.do_consume.assert_not_called() + dependency_mock.do_load_data.assert_called_once() + + def test_schedule_consumer_with_ready_dependencies(self): + dependency_mock = MagicMock() + dependency_mock.is_ready = True + dependency_mock.do_load_data = MagicMock() + + self.consumer.get_empty_dependencies.return_value = [dependency_mock] + self.scheduler._schedule_consumer(self.consumer) + dependency_mock.do_load_data.assert_called_once() diff --git a/msprobe/test/UT/base_ut/service_ut/test_manager_service.py b/msprobe/test/UT/base_ut/service_ut/test_manager_service.py new file mode 100644 index 0000000000..84fc7b3500 --- /dev/null +++ b/msprobe/test/UT/base_ut/service_ut/test_manager_service.py @@ -0,0 +1,79 @@ +import unittest +from unittest.mock import MagicMock, call, create_autospec, patch + +from msprobe.base import BaseComponent, BaseService, Scheduler, Service +from msprobe.utils.constants import CfgConst, CmdConst + + +class TestService(unittest.TestCase): + def setUp(self): + Service._services_map.clear() + + @patch("msprobe.base.service.manager.load_json") + @patch("msprobe.base.service.manager.valid_task") + def test_service_initialization_via_config(self, mock_valid_task, mock_load_json): + mock_load_json.return_value = {CfgConst.TASK: CfgConst.TASK_STAT} + mock_valid_task.return_value = CfgConst.TASK_STAT + mock_service_cls = MagicMock() + Service._services_map[CmdConst.DUMP] = mock_service_cls + cmd_namespace = MagicMock() + cmd_namespace.config_path = "dummy_path" + service = Service(cmd_namespace=cmd_namespace, key="value") + mock_load_json.assert_called_once_with("dummy_path") + mock_service_cls.assert_called_once_with(cmd_namespace=cmd_namespace, key="value") + self.assertEqual(service.service_instance, mock_service_cls.return_value) + + def test_service_registration(self): + @Service.register("test_service") + class TestServiceImpl: + pass + + self.assertIs(Service._services_map["test_service"], TestServiceImpl) + + @patch("msprobe.base.service.manager.load_json") + @patch("msprobe.base.service.manager.valid_task") + def test_service_method_delegation(self, mock_valid_task, mock_load_json): + mock_instance = MagicMock() + mock_instance.target_method = MagicMock(return_value="result") + mock_service_cls = MagicMock(return_value=mock_instance) + Service._services_map[CmdConst.DUMP] = mock_service_cls + mock_load_json.return_value = {CfgConst.TASK: CfgConst.TASK_STAT} + mock_valid_task.return_value = CfgConst.TASK_STAT + cmd_namespace = MagicMock() + cmd_namespace.config_path = "valid_path" + service = Service(cmd_namespace=cmd_namespace) + result = service.target_method("arg", kw=456) + mock_instance.target_method.assert_called_once_with("arg", kw=456) + self.assertEqual(result, "result") + mock_service_cls.assert_called_once_with(cmd_namespace=cmd_namespace) + + +class TestBaseService(unittest.TestCase): + @patch.object(Scheduler, "add") + @patch.object(Scheduler, "remove") + def test_full_lifecycle(self, mock_remove, mock_add): + class TestService(BaseService): + def construct(self): + self.high_pri = create_autospec(BaseComponent, name="high_pri") + self.high_pri.priority = 1 + self.low_pri = create_autospec(BaseComponent, name="low_pri") + self.low_pri.priority = 2 + self.non_comp = "non comp" + + service = TestService() + service.start() + mock_add.assert_called_once() + + @patch.object(BaseService, "init_start") + @patch.object(BaseService, "finalize_start") + def test_hook_execution_order(self, mock_final, mock_init): + class HookTestService(BaseService): + def construct(self): + pass + + service = HookTestService() + service.start() + mock_init.assert_called_once() + mock_final.assert_called_once() + self.assertEqual(mock_init.call_args_list[0], call()) + self.assertEqual(mock_final.call_args_list[-1], call()) diff --git a/msprobe/test/UT/base_ut/test_cmd.py b/msprobe/test/UT/base_ut/test_cmd.py new file mode 100644 index 0000000000..d9af9937ba --- /dev/null +++ b/msprobe/test/UT/base_ut/test_cmd.py @@ -0,0 +1,86 @@ +import unittest +from argparse import RawTextHelpFormatter +from unittest.mock import MagicMock, patch + +from msprobe.base import BaseCommand, Command +from msprobe.utils.constants import CmdConst, MsgConst +from msprobe.utils.exceptions import MsprobeException + + +class TestCommandRegistration(unittest.TestCase): + def setUp(self): + Command._cmd_map.clear() + + def test_register_command(self): + parent_cmd = None + cmd_name = "test" + + @Command.register(parent_cmd, cmd_name) + class TestCommand(BaseCommand): + pass + + self.assertIn(parent_cmd, Command._cmd_map) + self.assertIn(cmd_name, Command._cmd_map[parent_cmd]) + self.assertIs(Command._cmd_map[parent_cmd][cmd_name], TestCommand) + + def test_get_command(self): + parent1, parent2 = "parent1", "parent2" + cmd1, cmd2 = "cmd1", "cmd2" + + @Command.register(parent1, cmd1) + class Cmd1(BaseCommand): + pass + + @Command.register(parent2, cmd2) + class Cmd2(BaseCommand): + pass + + self.assertEqual(Command.get(parent1), {cmd1: Cmd1}) + self.assertEqual(Command.get(parent2), {cmd2: Cmd2}) + self.assertEqual(Command.get("invalid_parent"), {}) + + +class TestBaseCommand(unittest.TestCase): + class ConcreteCommand(BaseCommand): + def add_arguments(self, parse): + pass + + def setUp(self): + self.cmd = self.ConcreteCommand() + self.cmd.subcommand_level = 0 + + @patch("msprobe.base.cmd.argv", ["script", "arg1", "arg2"]) + def test_service_key_valid(self): + self.cmd.subcommand_level = 1 + self.assertEqual(self.cmd.service_key, "arg1") + + @patch("msprobe.base.cmd.argv", ["script"]) + def test_service_key_insufficient_args(self): + self.cmd.subcommand_level = 1 + self.assertIsNone(self.cmd.service_key) + + def test_service_key_invalid_level(self): + self.cmd.subcommand_level = "invalid" + with self.assertRaises(MsprobeException) as cm: + _ = self.cmd.service_key + self.assertEqual(str(cm.exception), f"{MsgConst.INVALID_ARGU} Subcommand level must be a positive integer.") + + @patch("msprobe.base.Command.get") + def test_build_parser_with_subcommands(self, mock_get): + class MockSubCommand: + @classmethod + def add_arguments(cls, parser): + pass + + mock_get.side_effect = [{"subcmd": MockSubCommand}, {}] + parent_parser = MagicMock() + fake_subparser = MagicMock() + subparsers = MagicMock() + parent_parser.add_subparsers.return_value = subparsers + subparsers.add_parser.return_value = fake_subparser + self.cmd.subcommand_level = 0 + self.cmd.build_parser(parent_parser, MagicMock()) + parent_parser.add_subparsers.assert_called_once_with(dest="L1command") + subparsers.add_parser.assert_called_once_with( + name="subcmd", help=CmdConst.HELP_SERVICE_MAP.get("subcmd"), formatter_class=RawTextHelpFormatter + ) diff --git a/msprobe/test/UT/base_ut/test_config.py b/msprobe/test/UT/base_ut/test_config.py new file mode 100644 index 0000000000..401711d6d9 --- /dev/null +++ b/msprobe/test/UT/base_ut/test_config.py @@ -0,0 +1,135 @@ +import unittest +from unittest.mock import MagicMock, patch + +from msprobe.base import BaseConfig, Dict2Class +from msprobe.utils.constants import CfgConst, MsgConst +from msprobe.utils.exceptions import MsprobeException + + +class ConcreteConfig(BaseConfig): + def check_config(self): + pass + + +class TestBaseConfig(unittest.TestCase): + def setUp(self): + self.mock_config = { + CfgConst.TASK: "test_task", + "test_task": {"key": "value"}, + CfgConst.FRAMEWORK: "test_framework", + CfgConst.STEP: [], + CfgConst.RANK: [], + CfgConst.LEVEL: [CfgConst.LEVEL_API], + CfgConst.LOG_LEVEL: "info", + CfgConst.SEED: None, + } + self.config_path = "dummy_path.json" + + @patch("msprobe.base.config.load_json") + def test_initialization(self, mock_load_json): + mock_load_json.return_value = self.mock_config + config = ConcreteConfig(self.config_path, task="test_task", step=[], level=[]) + self.assertEqual(config.config_path, self.config_path) + self.assertEqual(config.config, self.mock_config) + self.assertEqual(config.task, "test_task") + self.assertEqual(config.step, []) + self.assertEqual(config.level, []) + mock_load_json.assert_called_once_with(self.config_path) + + @patch("msprobe.base.config.load_json") + def test_common_check_calls(self, mock_load_json): + mock_load_json.return_value = self.mock_config + config = ConcreteConfig(self.config_path) + + with patch.multiple( + "msprobe.base.config", + valid_task=MagicMock(return_value="test_task"), + valid_framework=MagicMock(return_value="valid_framework"), + valid_step_or_rank=MagicMock(side_effect=lambda x: x), + valid_level=MagicMock(return_value=["valid_level"]), + valid_log_level=MagicMock(return_value="valid_log_level"), + valid_seed=MagicMock(return_value=42), + ) as mocks: + config._common_check() + + self.assertEqual(config.config[CfgConst.TASK], "test_task") + self.assertEqual(config.config[CfgConst.FRAMEWORK], "valid_framework") + self.assertEqual(config.config[CfgConst.STEP], []) + self.assertEqual(config.config[CfgConst.RANK], []) + self.assertEqual(config.config[CfgConst.LEVEL], ["valid_level"]) + self.assertEqual(config.config[CfgConst.LOG_LEVEL], "valid_log_level") + self.assertEqual(config.config[CfgConst.SEED], 42) + + @patch("msprobe.base.config.load_json") + def test_get_task_dict_success(self, mock_load_json): + mock_load_json.return_value = {CfgConst.TASK: "existing_task", "existing_task": {"key": "value"}} + config = ConcreteConfig(self.config_path) + config._get_task_dict() + self.assertEqual(config.task_config, {"key": "value"}) + + @patch("msprobe.base.config.load_json") + def test_get_task_dict_raises_exception(self, mock_load_json): + mock_load_json.return_value = {CfgConst.TASK: "non_existing_task"} + config = ConcreteConfig(self.config_path) + with self.assertRaises(MsprobeException) as context: + config._get_task_dict() + self.assertIn(f'Missing dictionary for key "non_existing_task".', context.exception.error_msg) + + @patch("msprobe.base.config.load_json") + def test_update_config(self, mock_load_json): + mock_load_json.return_value = self.mock_config + config = ConcreteConfig(self.config_path) + test_dict = {} + mock_check = MagicMock(return_value="checked_value") + config._update_config(test_dict, "test_key", mock_check, "test_value") + mock_check.assert_called_once_with("test_value") + self.assertEqual(test_dict["test_key"], "checked_value") + + @patch("msprobe.base.config.load_json") + def test_check_config_wrapper(self, mock_load_json): + mock_load_json.return_value = self.mock_config + config = ConcreteConfig(self.config_path) + with patch.object(config, "_common_check") as mock_common_check, patch.object( + config, "check_config" + ) as mock_check_config: + config.check_config() + mock_common_check.assert_called_once() + mock_check_config.assert_called_once() + self.assertEqual(config.task_config, {"key": "value"}) + + +class TestDict2Class(unittest.TestCase): + def test_basic_conversion(self): + data = {"name": "test", "value": 10} + obj = Dict2Class(data) + self.assertEqual(obj.name, "test") + self.assertEqual(obj.value, 10) + + def test_nested_dict_conversion(self): + data = {"nested": {"key": "value"}} + obj = Dict2Class(data) + self.assertIsInstance(obj.nested, Dict2Class) + self.assertEqual(obj.nested.key, "value") + + def test_service_key_processing(self): + data = {CfgConst.TASK: "special", "special": {"input": [[224, 224], "path/to/input"], "param": 5}} + obj = Dict2Class(data) + self.assertEqual(obj.input_shape, [224, 224]) + self.assertEqual(obj.input_path, "path/to/input") + self.assertEqual(obj.param, 5) + + def test_max_recursion_depth(self): + data = {} + current = data + for _ in range(MsgConst.MAX_RECURSION_DEPTH + 1): + current["nested"] = {} + current = current["nested"] + with self.assertRaises(MsprobeException) as context: + Dict2Class(data) + self.assertIn(f"Maximum recursion depth of {MsgConst.MAX_RECURSION_DEPTH}", str(context.exception)) + + def test_missing_attribute(self): + obj = Dict2Class({"existing": 1}) + with self.assertRaises(MsprobeException) as context: + _ = obj.non_existing + self.assertIn("has no attribute non_existing", str(context.exception)) diff --git a/msprobe/test/UT/common_ut/test_ascend.py b/msprobe/test/UT/common_ut/test_ascend.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/msprobe/test/UT/common_ut/test_cli.py b/msprobe/test/UT/common_ut/test_cli.py new file mode 100644 index 0000000000..eb21888134 --- /dev/null +++ b/msprobe/test/UT/common_ut/test_cli.py @@ -0,0 +1,47 @@ +import unittest +from unittest.mock import ANY, MagicMock, PropertyMock, call, patch + +from msprobe.common.cli import MainCommand, MsprobeException +from msprobe.utils.constants import CmdConst + + +class TestMainCommand(unittest.TestCase): + def setUp(self): + self.main_cmd = MainCommand() + self.mock_second_commands = {"cmd1": MagicMock(), "cmd2": MagicMock()} + self.main_cmd.second_commands = self.mock_second_commands + + @patch("msprobe.common.cli.ArgumentParser") + def test_init(self, mock_argparse): + main_cmd = MainCommand() + mock_argparse.assert_called_once_with(prog=CmdConst.PROG_NAME, description=ANY, formatter_class=ANY) + self.assertEqual(main_cmd.subcommand_level, 1) + self.assertIsNotNone(main_cmd.parser) + self.assertIsNotNone(main_cmd.subparser) + + def test_register(self): + with patch.object(MainCommand, "service_key", new_callable=PropertyMock) as mock_service_key: + mock_service_key.return_value = "cmd1" + mock_subparser = MagicMock() + self.main_cmd.subparser = mock_subparser + self.main_cmd.register() + expected_calls = [call(name="cmd1", help=None, formatter_class=self.main_cmd.formatter_class)] + mock_subparser.add_parser.assert_has_calls(expected_calls, any_order=True) + self.mock_second_commands["cmd1"].add_arguments.assert_called_once() + self.assertEqual(self.main_cmd.subcommand_level, 2) + + def test_parse(self): + mock_args = MagicMock() + self.main_cmd.parser.parse_args = MagicMock(return_value=mock_args) + result = self.main_cmd.parse() + self.assertEqual(result, mock_args) + self.main_cmd.parser.parse_args.assert_called_once() + + @patch("msprobe.common.cli.Service") + @patch("msprobe.common.cli.argv", ["msprobe", "invalid_service"]) + def test_execute_invalid_service(self, mock_service): + mock_service.get.return_value = False + args = MagicMock() + with self.assertRaises(MsprobeException) as context: + self.main_cmd.execute(args) + self.assertIn(" service is not registered", str(context.exception)) diff --git a/msprobe/test/UT/common_ut/test_validation.py b/msprobe/test/UT/common_ut/test_validation.py new file mode 100644 index 0000000000..191de2ac27 --- /dev/null +++ b/msprobe/test/UT/common_ut/test_validation.py @@ -0,0 +1,241 @@ +import unittest +from argparse import Namespace +from unittest.mock import MagicMock, patch + +from msprobe.common.validation import ( + CheckConfigPath, + CheckExec, + CheckFramework, + SafePath, + check_int_border, + parse_hyphen, + valid_config_path, + valid_exec, + valid_framework, + valid_level, + valid_log_level, + valid_seed, + valid_step_or_rank, + valid_task, +) +from msprobe.utils.exceptions import MsprobeException + + +class TestValidationFunctions(unittest.TestCase): + def setUp(self): + self.mock_cfgconst = MagicMock() + self.mock_cfgconst.ALL_TASK = ["train", "eval", "predict"] + self.mock_cfgconst.ALL_FRAMEWORK = ["tf", "pytorch"] + self.mock_cfgconst.ALL_LEVEL = ["info", "debug", "warning"] + self.patcher = patch.dict( + "sys.modules", + { + "msprobe.utils.constants.CfgConst": self.mock_cfgconst, + "msprobe.utils.constants.PathConst": MagicMock( + SUFFIX_SH=".sh", + SUFFIX_PY=".py", + SUFFIX_OFFLINE_MODEL=(".onnx", ".pb"), + SUFFIX_ONLINE_SCRIPT=(".sh", ".py"), + SUFFIX_JSON=".json", + DIR="dir", + FILE="file", + ), + }, + ) + self.patcher.start() + + def tearDown(self): + self.patcher.stop() + + def test_valid_task_valid(self): + self.assertEqual(valid_task("tensor"), "tensor") + + def test_valid_task_invalid(self): + with self.assertRaises(MsprobeException) as cm: + valid_task("invalid_task") + self.assertIn("must be one of ", str(cm.exception)) + + def test_valid_task_type_invalid(self): + with self.assertRaises(MsprobeException) as cm: + valid_task(123) + self.assertIn("[ERROR] invalid data type.", str(cm.exception)) + + def test_valid_exec_none(self): + self.assertEqual(valid_exec(None), None) + + def test_valid_exec_type_invalid(self): + with self.assertRaises(MsprobeException) as cm: + valid_exec(123) + self.assertIn("[ERROR] invalid data type.", str(cm.exception)) + + @patch("msprobe.common.validation.is_dir") + @patch("msprobe.common.validation.SafePath") + def test_valid_exec_directory(self, mock_safepath, mock_is_dir): + mock_is_dir.return_value = True + values = "/valid/directory" + result = valid_exec(values) + self.assertEqual(result, [values]) + mock_safepath.assert_called_once() + + def test_valid_exec_bash_invalid(self): + with self.assertRaises(MsprobeException) as cm: + valid_exec("bash invalid_script.py") + self.assertIn("[ERROR] Parsing failed.", str(cm.exception)) + + def test_valid_exec_python_invalid(self): + with self.assertRaises(MsprobeException) as cm: + valid_exec("python invalid_script.sh") + self.assertIn("[ERROR] Parsing failed.", str(cm.exception)) + + @patch("msprobe.common.validation.is_file") + @patch("msprobe.common.validation.SafePath") + def test_valid_exec_model_file(self, mock_safepath, mock_is_file): + mock_is_file.return_value = True + values = "model.onnx" + self.assertEqual(valid_exec(values), [values]) + mock_safepath.assert_called_once() + + @patch("msprobe.common.validation.is_dir") + @patch("msprobe.common.validation.is_file") + @patch("msprobe.common.validation.SafePath") + def test_check_exec_action(self, mock_safepath, mock_is_file, mock_is_dir): + mock_is_dir.return_value = False + mock_is_file.return_value = True + action = CheckExec(option_strings=["-e", "--exec"], dest="exec") + mock_namespace = Namespace() + test_values = "valid_script.sh" + with patch.object(SafePath, "check") as mock_check: + mock_check.return_value = test_values[0] + action(None, mock_namespace, test_values) + self.assertEqual(mock_namespace.exec, [test_values]) + mock_is_dir.assert_called_once_with(test_values) + mock_is_file.assert_called_once_with(test_values) + + @patch("msprobe.common.validation.SafePath") + def test_valid_config_path_valid(self, mock_safepath): + mock_safepath.return_value.check.return_value = "valid.json" + result = valid_config_path("config.json") + self.assertEqual(result, "valid.json") + + def test_valid_config_path(self): + self.option_strings = ["-c", "--config"] + self.dest = "config_path" + self.action = CheckConfigPath(option_strings=self.option_strings, dest=self.dest) + test_value = "/valid/path/config.json" + expected_result = "/verified/path/config.json" + mock_namespace = Namespace() + + with patch("msprobe.common.validation.valid_config_path") as mock_validator: + mock_validator.return_value = expected_result + self.action(parser=MagicMock(), namespace=mock_namespace, values=test_value) + mock_validator.assert_called_once_with(test_value) + self.assertEqual(getattr(mock_namespace, self.dest), expected_result) + + def test_valid_framework_valid(self): + self.assertEqual(valid_framework("mindie_llm"), "mindie_llm") + + def test_valid_framework_invalid(self): + self.assertEqual(valid_framework(""), "") + + def test_valid_framework_type_invalid(self): + with self.assertRaises(MsprobeException) as cm: + valid_framework(123) + self.assertIn("[ERROR] invalid data type.", str(cm.exception)) + + def test_valid_framework_more_element_invalid(self): + with self.assertRaises(MsprobeException) as cm: + valid_framework("invalid_fw") + self.assertIn('[ERROR] invalid argument. "framework" must be one of', str(cm.exception)) + + def test_check_framework(self): + self.option_strings = ["-f", "--framework"] + self.dest = "framework" + self.action = CheckFramework(option_strings=self.option_strings, dest=self.dest) + test_value = "mindie_llm" + mock_namespace = Namespace() + with patch("msprobe.common.validation.valid_framework") as mock_validator: + mock_validator.return_value = test_value + self.action(parser=MagicMock(), namespace=mock_namespace, values=test_value) + mock_validator.assert_called_once_with(test_value) + self.assertEqual(getattr(mock_namespace, self.dest), test_value) + + def test_check_int_border_valid(self): + check_int_border(0, 500000, 1000000) + + def test_valid_check_int_type_invalid(self): + with self.assertRaises(MsprobeException) as cm: + check_int_border([0.35]) + self.assertIn("[ERROR] invalid data type.", str(cm.exception)) + + def test_check_int_border_invalid(self): + with self.assertRaises(MsprobeException) as cm: + check_int_border(-1) + self.assertIn("The integer range is limited to ", str(cm.exception)) + with self.assertRaises(MsprobeException): + check_int_border(1000000000000000000) + + def test_parse_hyphen_valid(self): + self.assertEqual(parse_hyphen("100-200"), list(range(100, 201))) + self.assertEqual(parse_hyphen("100-200-2"), list(range(100, 201, 2))) + + def test_parse_hyphen_invalid(self): + with self.assertRaises(MsprobeException): + parse_hyphen("100-200-300-400") + with self.assertRaises(MsprobeException): + parse_hyphen("200-100") + + def test_valid_step_or_rank(self): + self.assertEqual(valid_step_or_rank([10, "20-22", "30-35-2"]), [10, 20, 21, 22, 30, 32, 34]) + + def test_valid_step_or_rank_none(self): + self.assertEqual(valid_step_or_rank([]), []) + + def test_valid_step_or_rank_type_invalid(self): + with self.assertRaises(MsprobeException) as cm: + valid_step_or_rank(123) + + def test_valid_step_or_rank_invalid(self): + with self.assertRaises(MsprobeException) as cm: + valid_step_or_rank([0.35]) + + def test_valid_level_valid_none(self): + self.assertEqual(valid_level(""), "") + + def test_valid_level_invalid_type(self): + with self.assertRaises(MsprobeException) as cm: + valid_level(123) + self.assertIn("[ERROR] invalid data type.", str(cm.exception)) + + def test_valid_level_valid(self): + self.assertEqual(valid_level(["L0", "L1"]), ["L0", "L1"]) + + def test_valid_log_level_valid_none(self): + self.assertEqual(valid_log_level(None), None) + + def test_valid_log_level_invalid_type(self): + with self.assertRaises(MsprobeException) as cm: + valid_log_level(123) + self.assertIn("[ERROR] invalid data type.", str(cm.exception)) + + def test_valid_level_invalid(self): + with self.assertRaises(MsprobeException): + valid_level(["invalid_level"]) + + def test_valid_log_level_valid(self): + self.assertEqual(valid_log_level("info"), "info") + + def test_valid_log_level_invalid(self): + with self.assertRaises(MsprobeException): + valid_log_level("invalid") + + def test_valid_seed_valid_none(self): + self.assertEqual(valid_seed(None), None) + + def test_valid_seed_valid(self): + self.assertEqual(valid_seed(42), 42) + + def test_valid_seed_invalid(self): + with self.assertRaises(MsprobeException): + valid_seed("not_an_int") + with self.assertRaises(MsprobeException): + valid_seed(-1) -- Gitee