diff --git a/msprobe/test/UT/csrc_ut/CMakeLists.txt b/msprobe/test/UT/csrc_ut/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..5fe115c4594493f5f55cf535dde5b527804ab2bf --- /dev/null +++ b/msprobe/test/UT/csrc_ut/CMakeLists.txt @@ -0,0 +1,23 @@ +project(msprobe VERSION 1.0.0 LANGUAGES CXX C) +cmake_minimum_required(VERSION 3.14) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +find_package(cpython MODULE REQUIRED) +find_package(gtest MODULE REQUIRED) +find_package(mockcpp MODULE REQUIRED) + +add_executable(msprobe_test) +target_link_libraries(msprobe_test PRIVATE ${gtest_LIBRARIES}) +target_link_libraries(msprobe_test PRIVATE ${mockcpp_LIBRARIES}) +target_link_libraries(msprobe_test PRIVATE msprobe_c) + +target_include_directories(msprobe_test PRIVATE $ENV{PROJECT_ROOT_PATH}/msprobe/csrc) +target_include_directories(msprobe_test PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include) +target_include_directories(msprobe_test PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/mock) + +target_compile_definitions(msprobe_test PRIVATE __RESOURCES_PATH__="${CMAKE_CURRENT_SOURCE_DIR}/../resources") + +file(GLOB_RECURSE SOURCES "*.cpp") +target_sources(msprobe_test PUBLIC ${SOURCES}) diff --git a/msprobe/test/UT/csrc_ut/utils_ut/test_log.cpp b/msprobe/test/UT/csrc_ut/utils_ut/test_log.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/msprobe/test/UT/utils_ut/test_dependencies.py b/msprobe/test/UT/utils_ut/test_dependencies.py new file mode 100644 index 0000000000000000000000000000000000000000..6a564aa59a8253ec1594ba8b5e34e1c493291149 --- /dev/null +++ b/msprobe/test/UT/utils_ut/test_dependencies.py @@ -0,0 +1,113 @@ +import os +import sys +import unittest +from unittest.mock import MagicMock, patch + +from msprobe.utils.dependencies import DependencyManager, temporary_tf_log_level +from msprobe.utils.exceptions import MsprobeException + + +class TestDependencyManager(unittest.TestCase): + def setUp(self): + DependencyManager._instance = None + self.manager = DependencyManager() + + def tearDown(self): + if "TF_CPP_MIN_LOG_LEVEL" in os.environ: + del os.environ["TF_CPP_MIN_LOG_LEVEL"] + + @patch.dict(os.environ, {"TF_CPP_MIN_LOG_LEVEL": "0"}) + def test_temporary_tf_log_level(self): + @temporary_tf_log_level + def mock_function(): + return os.environ["TF_CPP_MIN_LOG_LEVEL"] + + self.assertEqual(mock_function(), "2") + self.assertEqual(os.environ["TF_CPP_MIN_LOG_LEVEL"], "0") + + @patch("msprobe.utils.dependencies.import_module") + def test_get_tensorflow(self, mock_import_module): + mock_tf = MagicMock() + mock_tf.__version__ = "2.6.5" + mock_rewriter_config = MagicMock() + mock_convert_variables = MagicMock() + + def side_effect(name): + if name == "tensorflow": + return mock_tf + return MagicMock() + + mock_import_module.side_effect = side_effect + sys.modules["tensorflow"] = mock_tf + sys.modules["tensorflow.core.protobuf.rewriter_config_pb2"] = MagicMock(RewriterConfig=mock_rewriter_config) + sys.modules["tensorflow.python.framework.graph_util"] = MagicMock( + convert_variables_to_constants=mock_convert_variables + ) + dm = DependencyManager() + tf, re_writer_config, sm2pb = dm.get_tensorflow() + + self.assertIsNotNone(tf, "TensorFlow is not None") + self.assertEqual(tf, mock_tf) + self.assertEqual(re_writer_config, mock_rewriter_config) + self.assertEqual(sm2pb, mock_convert_variables) + + @patch("msprobe.utils.dependencies.import_module") + def test_import_package_non_tensorflow(self, mock_import): + mock_module = MagicMock() + mock_import.return_value = mock_module + result = self.manager._import_package("abc") + mock_import.assert_called_once_with("abc") + self.assertEqual(result, mock_module) + self.assertIn("abc", self.manager._dependencies) + + @patch.object(DependencyManager, "_import_tensorflow") + def test_import_package_tensorflow(self, mock_import_tf): + mock_tf = MagicMock() + + def simulate_import(): + self.manager._dependencies["tensorflow"] = mock_tf + return mock_tf + + mock_import_tf.side_effect = simulate_import + result = self.manager._import_package("tensorflow") + mock_import_tf.assert_called_once() + self.assertEqual(result, mock_tf) + self.assertIn("tensorflow", self.manager._dependencies) + + @patch("msprobe.utils.dependencies.import_module") + def test_import_tensorflow_wrong_version(self, mock_import): + mock_tf = MagicMock() + mock_tf.__version__ = "2.7.0" + mock_import.return_value = mock_tf + with self.assertRaises(MsprobeException) as context: + self.manager._import_tensorflow() + self.assertIn("Incompatible versions", str(context.exception)) + + @patch("msprobe.utils.dependencies.import_module") + def test_import_tensorflow_environment_reset(self, mock_import): + original_level = "0" + os.environ["TF_CPP_MIN_LOG_LEVEL"] = original_level + mock_tf = MagicMock() + mock_tf.__version__ = "2.6.5" + mock_import.return_value = mock_tf + self.manager._import_tensorflow() + self.assertEqual(os.environ["TF_CPP_MIN_LOG_LEVEL"], original_level) + + @patch("msprobe.utils.dependencies.import_module") + def test_import_package_missing_dependency(self, mock_import): + mock_import.side_effect = ImportError("No module named 'missing_package'") + result = self.manager._import_package("missing_package") + self.assertIsNone(result) + self.assertNotIn("missing_package", self.manager._dependencies) + + @patch("msprobe.utils.dependencies.logger.warning") + @patch("msprobe.utils.dependencies.import_module") + def test_safely_import_decorator(self, mock_import, mock_warning): + mock_import.side_effect = ImportError("Test error") + result = self.manager._import_package("test_package") + self.assertIsNone(result) + mock_warning.assert_called_once_with("test_package is not installed. Please install it if needed.") + mock_warning.reset_mock() + result = self.manager._import_package("test_package") + self.assertIsNone(result) + mock_warning.assert_not_called() diff --git a/msprobe/test/UT/utils_ut/test_env.py b/msprobe/test/UT/utils_ut/test_env.py new file mode 100644 index 0000000000000000000000000000000000000000..28d4428dc210bded92f01c96fc70c491acd81c25 --- /dev/null +++ b/msprobe/test/UT/utils_ut/test_env.py @@ -0,0 +1,102 @@ +import os +import unittest +from unittest import mock + +from msprobe.utils.env import EnvVarManager +from msprobe.utils.exceptions import MsprobeException + + +class TestEnvVarManager(unittest.TestCase): + def setUp(self): + self.manager = EnvVarManager() + self.manager.set_prefix("") + self.env_patcher = mock.patch.dict(os.environ, clear=True) + self.env_patcher.start() + + def tearDown(self): + self.env_patcher.stop() + + def test_singleton_instance(self): + manager1 = EnvVarManager() + manager2 = EnvVarManager() + self.assertIs(manager1, manager2) + + def test_set_prefix(self): + self.manager.set_prefix("TEST_") + self.assertEqual(self.manager.prefix, "TEST_") + + def test_get_existing_var_no_prefix(self): + os.environ["KEY"] = "value" + result = self.manager.get("KEY") + self.assertEqual(result, "value") + + def test_get_existing_var_with_prefix(self): + self.manager.set_prefix("TEST_") + os.environ["TEST_KEY"] = "value" + result = self.manager.get("TEST_KEY") + self.assertEqual(result, "value") + + def test_get_missing_var_required(self): + with self.assertRaises(MsprobeException) as cm: + self.manager.get("MISSING_KEY", required=True) + self.assertIn("MISSING_KEY", str(cm.exception)) + + def test_get_missing_var_optional_with_default(self): + result = self.manager.get("MISSING_KEY", default="default_val", required=False) + self.assertEqual(result, "default_val") + + def test_get_cast_type_success(self): + os.environ["INT_VAL"] = "123" + result = self.manager.get("INT_VAL", cast_type=int) + self.assertEqual(result, 123) + + def test_get_cast_type_failure(self): + os.environ["INVALID_INT"] = "abc" + with self.assertRaises(MsprobeException) as cm: + self.manager.get("INVALID_INT", cast_type=int) + self.assertIn("Failed to cast", str(cm.exception)) + + def test_set_var_with_prefix(self): + self.manager.set_prefix("TEST_") + self.manager.set("NEW_KEY", "value") + self.assertEqual(os.environ["NEW_KEY"], "value") + + def test_delete_existing_var(self): + os.environ["TEST_KEY"] = "value" + self.manager.set_prefix("TEST_") + self.manager.delete("KEY") + self.assertIn("TEST_KEY", os.environ) + + def test_delete_non_existing_var(self): + self.manager.set_prefix("TEST_") + try: + self.manager.delete("NON_EXISTENT") + except Exception: + self.fail("Deleting non-existent variable raised unexpected exception") + + def test_list_all_with_prefix(self): + os.environ.update({"TEST_A": "1", "TEST_B": "2", "OTHER": "3"}) + self.manager.set_prefix("TEST_") + result = self.manager.list_all() + expected = {"TEST_A": "1", "TEST_B": "2"} + self.assertDictEqual(result, expected) + + def test_list_all_without_prefix(self): + os.environ["KEY"] = "value" + result = self.manager.list_all() + self.assertIn("KEY", result) + + @mock.patch("msprobe.utils.log.logger.debug") + def test_logging_on_get(self, mock_debug): + os.environ["LOGGED_KEY"] = "log_value" + self.manager.get("LOGGED_KEY") + mock_debug.assert_called_with("Accessed environment variable LOGGED_KEY, Value: log_value.") + + @mock.patch("msprobe.utils.log.logger.debug") + def test_logging_on_set(self, mock_debug): + self.manager.set("LOGGED_SET", "value") + mock_debug.assert_called_with("Set environment variable LOGGED_SET to value.") + + +if __name__ == "__main__": + unittest.main() diff --git a/msprobe/test/UT/utils_ut/test_hijack.py b/msprobe/test/UT/utils_ut/test_hijack.py new file mode 100644 index 0000000000000000000000000000000000000000..6a0bc86fe075733b7d186d9b9b3dc1128ce47c6f --- /dev/null +++ b/msprobe/test/UT/utils_ut/test_hijack.py @@ -0,0 +1,291 @@ +import sys +import unittest +from unittest.mock import ANY, MagicMock, Mock, patch + +from msprobe.utils.exceptions import MsprobeException +from msprobe.utils.hijack import ( + ActionType, + HiJackerManager, + HiJackerPathFinder, + HijackerUnit, + HiJackerWrapperFunction, + HiJackerWrapperModule, + HiJackerWrapperObj, + HijackHandler, + release, +) + + +class TestHijackerUnit(unittest.TestCase): + def test_valid_parameters(self): + stub = MagicMock() + unit = HijackerUnit(stub, "module", "cls", "func", ActionType.REPLACE, 100) + self.assertEqual(unit.stub, stub) + self.assertEqual(unit.module, "module") + self.assertEqual(unit.cls, "cls") + self.assertEqual(unit.function, "func") + self.assertEqual(unit.action, ActionType.REPLACE) + self.assertEqual(unit.priority, 100) + + def test_invalid_stub(self): + with self.assertRaises(MsprobeException): + HijackerUnit("not_callable", "module", "", "", ActionType.REPLACE, 100) + + def test_missing_module(self): + with self.assertRaises(MsprobeException): + HijackerUnit(MagicMock(), "", "", "", ActionType.REPLACE, 100) + + def test_invalid_action(self): + with self.assertRaises(MsprobeException): + HijackerUnit(MagicMock(), "module", "", "", 999, 100) + + def test_replace_module_error(self): + with self.assertRaises(MsprobeException): + HijackerUnit(MagicMock(), "module", "", "", ActionType.REPLACE, 100) + + +class TestHijackerUnit(unittest.TestCase): + + def test_valid_parameters(self): + mock_stub = MagicMock() + unit = HijackerUnit(mock_stub, "module_name", "ClassName", "function_name", ActionType.REPLACE, 1) + self.assertEqual(unit.module, "module_name") + + def test_invalid_stub(self): + with self.assertRaises(MsprobeException) as context: + HijackerUnit("not_callable", "module_name", "ClassName", "function_name", ActionType.REPLACE, 1) + self.assertIn('"stub" should be callable.', str(context.exception)) + + def test_missing_module(self): + mock_stub = MagicMock() + with self.assertRaises(MsprobeException) as context: + HijackerUnit(mock_stub, None, "ClassName", "function_name", ActionType.REPLACE, 1) + self.assertIn('"module" is required.', str(context.exception)) + + def test_invalid_module_type(self): + mock_stub = MagicMock() + with self.assertRaises(MsprobeException) as context: + HijackerUnit(mock_stub, 123, "ClassName", "function_name", ActionType.REPLACE, 1) + self.assertIn('"module" should be a str.', str(context.exception)) + + def test_invalid_cls_type(self): + mock_stub = MagicMock() + with self.assertRaises(MsprobeException) as context: + HijackerUnit(mock_stub, "module_name", 123, "function_name", ActionType.REPLACE, 1) + self.assertIn('"cls" should be a str.', str(context.exception)) + + def test_missing_function_when_cls_present(self): + mock_stub = MagicMock() + with self.assertRaises(MsprobeException) as context: + HijackerUnit(mock_stub, "module_name", "ClassName", None, ActionType.REPLACE, 1) + self.assertIn('"function" should be used when "cls" used.', str(context.exception)) + + def test_invalid_function_type(self): + mock_stub = MagicMock() + with self.assertRaises(MsprobeException) as context: + HijackerUnit(mock_stub, "module_name", "ClassName", 123, ActionType.REPLACE, 1) + self.assertIn('"function" should be a str.', str(context.exception)) + + def test_invalid_action(self): + mock_stub = MagicMock() + with self.assertRaises(MsprobeException) as context: + HijackerUnit(mock_stub, "module_name", "ClassName", "function_name", "INVALID_ACTION", 1) + self.assertIn('"action" should be an ActionType.', str(context.exception)) + + def test_module_replacement_not_supported(self): + mock_stub = MagicMock() + with self.assertRaises(MsprobeException) as context: + HijackerUnit(mock_stub, "module_name", None, None, ActionType.REPLACE, 1) + self.assertIn("replacement of a module is not supported", str(context.exception)) + + def test_invalid_priority_type(self): + mock_stub = MagicMock() + with self.assertRaises(MsprobeException) as context: + HijackerUnit(mock_stub, "module_name", "ClassName", "function_name", ActionType.REPLACE, "high") + self.assertIn("Expected int type, but got str", str(context.exception)) + + +class TestRelease(unittest.TestCase): + @patch("msprobe.utils.hijack.HiJackerManager") + def test_release_valid_handler(self, mock_manager): + handler = MagicMock(spec=HijackHandler) + handler.released = False + handler.unit = "test_unit" + release(handler) + self.assertTrue(handler.released) + mock_manager.remove_unit.assert_called_once_with("test_unit") + + def test_release_with_invalid_handler_type(self): + invalid_handler = "not_a_handler" + with self.assertRaises(MsprobeException) as context: + release(invalid_handler) + self.assertIn("Handler must be an instance of HijackHandler.", str(context.exception)) + + +class TestHijackerManager(unittest.TestCase): + def setUp(self): + HiJackerManager._initialized = False + HiJackerManager._hijacker_units = {} + HiJackerManager._hijacker_wrappers = {} + + @patch("sys.meta_path", []) + def test_initialize(self): + HiJackerManager.initialize() + self.assertTrue(HiJackerManager._initialized) + self.assertIsInstance(sys.meta_path[0], HiJackerPathFinder) + + def test_add_and_remove_unit(self): + stub = MagicMock() + unit = HijackerUnit(stub, "test_module", "", "test_func", ActionType.REPLACE, 100) + handler = HiJackerManager.add_unit(unit) + self.assertIn(handler, HiJackerManager._hijacker_units) + wrapper = HiJackerManager._hijacker_wrappers.get("test_module--test_func") + self.assertIsInstance(wrapper, HiJackerWrapperFunction) + self.assertEqual(len(wrapper.replacement), 1) + HiJackerManager.remove_unit(handler) + self.assertNotIn(handler, HiJackerManager._hijacker_units) + self.assertNotIn("test_module--test_func", HiJackerManager._hijacker_wrappers) + + +class ConcreteHiJackerWrapper(HiJackerWrapperObj): + def activate(self): + pass + + def deactivate(self): + pass + + +class TestRemoveUnit(unittest.TestCase): + def setUp(self): + self.hijacker = ConcreteHiJackerWrapper("mod-class-func") + self.unit_replace = Mock(action=ActionType.REPLACE, priority=1) + self.unit_pre = Mock(action=ActionType.PRE_HOOK, priority=2) + self.unit_post = Mock(action=ActionType.POST_HOOK, priority=3) + + def test_remove_replace_unit(self): + self.hijacker.replacement.append(self.unit_replace) + self.hijacker.remove_unit(self.unit_replace) + self.assertNotIn(self.unit_replace, self.hijacker.replacement) + self.assertEqual(len(self.hijacker.replacement), 0) + + def test_remove_pre_hook_unit(self): + self.hijacker.pre_hooks.append(self.unit_pre) + self.hijacker.remove_unit(self.unit_pre) + self.assertNotIn(self.unit_pre, self.hijacker.pre_hooks) + self.assertEqual(len(self.hijacker.pre_hooks), 0) + + def test_remove_post_hook_unit(self): + self.hijacker.post_hooks.append(self.unit_post) + self.hijacker.remove_unit(self.unit_post) + self.assertNotIn(self.unit_post, self.hijacker.post_hooks) + self.assertEqual(len(self.hijacker.post_hooks), 0) + + def test_remove_non_existent_unit_raises_error(self): + with self.assertRaises(ValueError): + self.hijacker.remove_unit(self.unit_replace) + self.hijacker.pre_hooks.append(Mock(action=ActionType.PRE_HOOK)) + with self.assertRaises(ValueError): + self.hijacker.remove_unit(self.unit_pre) + self.hijacker.post_hooks.append(Mock(action=ActionType.POST_HOOK)) + with self.assertRaises(ValueError): + self.hijacker.remove_unit(self.unit_post) + + def test_remove_from_multiple_units(self): + unit1 = Mock(action=ActionType.REPLACE, priority=1) + unit2 = Mock(action=ActionType.REPLACE, priority=2) + self.hijacker.replacement = [unit1, unit2] + self.hijacker.remove_unit(unit1) + self.assertEqual(self.hijacker.replacement, [unit2]) + + def test_remove_does_not_affect_other_lists(self): + self.hijacker.replacement.append(self.unit_replace) + self.hijacker.pre_hooks.append(self.unit_pre) + self.hijacker.post_hooks.append(self.unit_post) + + self.hijacker.remove_unit(self.unit_replace) + self.assertIn(self.unit_pre, self.hijacker.pre_hooks) + self.assertIn(self.unit_post, self.hijacker.post_hooks) + + +class TestHijackerWrapperModule(unittest.TestCase): + def setUp(self): + self.wrapper = HiJackerWrapperModule("test_module--") + + @patch("msprobe.utils.hijack.HiJackerPathFinder.add_mod") + def test_activate(self, mock_add_mod): + self.wrapper.activate() + mock_add_mod.assert_called_once_with("test_module") + + def test_exec_pre_post_hooks(self): + pre_unit = MagicMock(action=ActionType.PRE_HOOK, stub=MagicMock()) + post_unit = MagicMock(action=ActionType.POST_HOOK, stub=MagicMock()) + self.wrapper.add_unit(pre_unit) + self.wrapper.add_unit(post_unit) + mock_module = MagicMock() + self.wrapper.exec_pre_hook() + pre_unit.stub.assert_called_once() + self.wrapper.exec_post_hook(mock_module) + post_unit.stub.assert_called_once_with(mock_module) + + +class TestHiJackerWrapperFunction(unittest.TestCase): + def setUp(self): + self.target_name = "test_mod-TestClass-test_method" + self.wrapper = HiJackerWrapperFunction(self.target_name) + + self.mock_module = MagicMock() + self.mock_class = MagicMock() + self.original_method = MagicMock() + self.mock_module.TestClass = self.mock_class + self.mock_class.test_method = self.original_method + + def test_initialization(self): + self.assertEqual(self.wrapper.mod_name, "test_mod") + self.assertEqual(self.wrapper.class_name, "TestClass") + self.assertEqual(self.wrapper.func_name, "test_method") + + @patch("msprobe.utils.hijack.hijacker") + @patch.dict("msprobe.utils.hijack.sys.modules", {"test_mod": None}) + def test_activate_module_not_loaded(self, mock_hijacker): + self.wrapper.activate() + mock_hijacker.assert_called_once_with(stub=ANY, module="test_mod", action=ActionType.POST_HOOK, priority=0) + + def test_wrapper_execution_flow(self): + pre_hook = MagicMock() + pre_hook.stub = MagicMock(return_value=(("modified_args",), {"new_kw": 1})) + replacement = MagicMock() + replacement.stub = MagicMock(return_value="replaced_result") + post_hook = MagicMock() + post_hook.stub = MagicMock(return_value="final_result") + + self.wrapper.pre_hooks = [pre_hook] + self.wrapper.replacement = [replacement] + self.wrapper.post_hooks = [post_hook] + self.wrapper.ori_obj = MagicMock() + + result = self.wrapper._get_wrapper()("arg1", kw1=2) + + pre_hook.stub.assert_called_once_with("arg1", kw1=2) + replacement.stub.assert_called_once_with("modified_args", new_kw=1) + post_hook.stub.assert_called_once_with("replaced_result", "modified_args", new_kw=1) + self.assertEqual(result, "final_result") + + def test_pre_hook_type_check(self): + invalid_hook = MagicMock() + invalid_hook.stub = MagicMock(return_value="invalid_type") + self.wrapper.pre_hooks = [invalid_hook] + self.wrapper.ori_obj = MagicMock() + + with self.assertRaises(MsprobeException) as cm: + self.wrapper._get_wrapper()("arg1") + self.assertIn("Pre-hook must return a tuple", str(cm.exception)) + + @patch("msprobe.utils.hijack.release") + @patch.dict("msprobe.utils.hijack.sys.modules", {"test_mod": sys}) + def test_deactivate_with_missing_class(self, mock_release): + self.wrapper.class_name = "NonExistentClass" + self.wrapper.ori_obj = self.original_method + self.wrapper.mod_hijacker = MagicMock() + self.wrapper.deactivate() + self.assertIsNone(self.wrapper.ori_obj) + mock_release.assert_called_once()