# Patch summary
# -------------
# * msprobe/pytorch/__init__.py: export module_dump and module_dump_end from
#   the new functional.module_dump module (and add the missing final newline).
# * functional/data_processor.py (empty) and functional/dump_module.py are
#   deleted; functional/module_dump.py replaces dump_module.py.
# * New module_dump.py keeps every handle returned by the register_* calls in
#   the module-level list hook_handle_list, so module_dump_end() can remove
#   the hooks again (the deleted implementation discarded the handles).
#   Hook registration is now split on torch_version_above_or_equal_2; the
#   deleted implementation always passed with_kwargs=True and always called
#   register_full_backward_pre_hook.
# * service.py: Service.check_register_full_backward_hook(module) no longer
#   takes or registers a hook; it only clears old-style backward hooks. The
#   three call sites in this file now call module.register_full_backward_hook
#   themselves.
# * Tests: test_dump_module.py is replaced by test_module_dump.py.
#
# Review notes
# ------------
# * test_module_dump.py adds test_module_dump and test_module_dump_end fully
#   commented out. As committed, module_dump() and module_dump_end() have no
#   test, and the imports patch, MsprobeException and logger are only used by
#   the commented-out code. Enable the tests or drop them with the imports.
# * The deleted module_dump() appended a per-name counter (module_count) to
#   dump_name. The new prefix is Module.<dump_name>.<ClassName>. with no
#   counter, so two modules of the same class dumped under the same name get
#   the same prefix. NOTE(review): confirm whether a later stage makes the
#   names unique.
# * remove_hook() silently skips any list entry that is not a
#   torch.utils.hooks.RemovableHandle.
# * NOTE(review): this patch text was recovered from a copy whose newlines and
#   runs of whitespace had been collapsed. Hunk line counts were re-checked,
#   but leading indentation of context lines in service.py, the whitespace on
#   the whitespace-only "-" lines, and the nesting inside the commented-out
#   tests are reconstructed; confirm against the original commit before
#   applying.
diff --git a/debug/accuracy_tools/msprobe/pytorch/__init__.py b/debug/accuracy_tools/msprobe/pytorch/__init__.py
index c4e426772670212382addb9b855b4bdf69810d3d..1c92c2abeb15c414163b986ca40b20233d4a0982 100644
--- a/debug/accuracy_tools/msprobe/pytorch/__init__.py
+++ b/debug/accuracy_tools/msprobe/pytorch/__init__.py
@@ -1,4 +1,5 @@
 from .debugger.precision_debugger import PrecisionDebugger
 from .common.utils import seed_all
 from .compare.distributed_compare import compare_distributed
-from .compare.pt_compare import compare
\ No newline at end of file
+from .compare.pt_compare import compare
+from .functional.module_dump import module_dump, module_dump_end
diff --git a/debug/accuracy_tools/msprobe/pytorch/functional/data_processor.py b/debug/accuracy_tools/msprobe/pytorch/functional/data_processor.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/debug/accuracy_tools/msprobe/pytorch/functional/dump_module.py b/debug/accuracy_tools/msprobe/pytorch/functional/dump_module.py
deleted file mode 100644
index 5d2e8d9856c59bf715e1e5f7ab01c39dd7de73ed..0000000000000000000000000000000000000000
--- a/debug/accuracy_tools/msprobe/pytorch/functional/dump_module.py
+++ /dev/null
@@ -1,39 +0,0 @@
-import torch.nn as nn
-from msprobe.pytorch.common.log import logger
-from msprobe.core.common.const import Const
-from msprobe.pytorch.hook_module.api_registry import api_register
-from msprobe.pytorch.debugger.precision_debugger import PrecisionDebugger
-from msprobe.core.common.exceptions import MsprobeException
-from msprobe.core.data_dump.scope import BaseScope
-
-module_count = {}
-
-
-def module_dump(module, dump_name):
-    if not isinstance(module, nn.Module):
-        logger.error("The parameter:module in module_dump is not a Module subclass.")
-        raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
-    if not isinstance(dump_name, str):
-        logger.error("The parameter:dump_name in module_dump is not a str type.")
-        raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
-    api_register.api_originality()
-    if dump_name not in module_count:
-        module_count[dump_name] = 0
-    else:
-        module_count[dump_name] += 1
-    dump_name = dump_name + Const.SEP + str(module_count.get(dump_name)) + Const.SEP
-
-    pdg = PrecisionDebugger()
-    _, forward_hook, backward_hook, _ = pdg.service.build_hook(BaseScope.Module_Type_Module, dump_name)
-    module.register_forward_hook(forward_hook, with_kwargs=True)
-    module.register_full_backward_hook(backward_hook)
-
-    module.register_forward_pre_hook(pdg.service.module_processor.node_hook(dump_name + Const.FORWARD, Const.START))
-    module.register_forward_hook(pdg.service.module_processor.node_hook(dump_name + Const.FORWARD, Const.STOP))
-    module.register_full_backward_pre_hook(
-        pdg.service.module_processor.node_hook(dump_name + Const.BACKWARD, Const.START))
-    module.register_full_backward_hook(pdg.service.module_processor.node_hook(dump_name + Const.BACKWARD, Const.STOP))
-
-
-def module_dump_end():
-    api_register.api_modularity()
diff --git a/debug/accuracy_tools/msprobe/pytorch/functional/module_dump.py b/debug/accuracy_tools/msprobe/pytorch/functional/module_dump.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1eb8ab77527f0a2b27146893121916e81be4d2c
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/pytorch/functional/module_dump.py
@@ -0,0 +1,84 @@
+# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch.nn as nn
+from msprobe.core.common.const import Const
+from msprobe.core.common.exceptions import MsprobeException
+from msprobe.core.data_dump.scope import BaseScope
+from msprobe.pytorch.common.log import logger
+from msprobe.pytorch.debugger.precision_debugger import PrecisionDebugger
+from msprobe.pytorch.hook_module.api_registry import api_register
+from msprobe.pytorch.service import torch_version_above_or_equal_2
+
+hook_handle_list = []
+
+
+def module_dump(module, dump_name):
+    if not isinstance(module, nn.Module):
+        logger.error("The parameter module in module_dump must be a Module subclass.")
+        raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
+    if not isinstance(dump_name, str):
+        logger.error("The parameter dump_name in module_dump must be a str type.")
+        raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
+
+    api_register.api_originality()
+    register_hook(module, dump_name)
+
+
+def module_dump_end():
+    api_register.api_modularity()
+    remove_hook()
+    hook_handle_list.clear()
+
+
+def register_hook(module, dump_name):
+    prefix = BaseScope.Module_Type_Module + Const.SEP + dump_name + Const.SEP + module.__class__.__name__ + Const.SEP
+
+    pdg = PrecisionDebugger()
+    _, forward_hook, backward_hook, forward_hook_torch_version_below_2 = \
+        pdg.service.build_hook(BaseScope.Module_Type_Module, prefix)
+
+    if torch_version_above_or_equal_2:
+        forward_hook_handle = module.register_forward_hook(forward_hook, with_kwargs=True)
+        hook_handle_list.append(forward_hook_handle)
+    else:
+        pdg.service.check_register_full_backward_hook(module)
+        full_backward_hook_handle = module.register_full_backward_hook(
+            pdg.service.module_processor.node_hook(prefix + Const.BACKWARD, Const.STOP))
+        forward_hook_handle = module.register_forward_hook(forward_hook_torch_version_below_2)
+        hook_handle_list.extend([full_backward_hook_handle, forward_hook_handle])
+    pdg.service.check_register_full_backward_hook(module)
+    full_backward_hook_handle = module.register_full_backward_hook(backward_hook)
+
+    forward_pre_hook_handle = module.register_forward_pre_hook(
+        pdg.service.module_processor.node_hook(prefix + Const.FORWARD, Const.START))
+    forward_hook_handle = module.register_forward_hook(
+        pdg.service.module_processor.node_hook(prefix + Const.FORWARD, Const.STOP))
+    hook_handle_list.extend([full_backward_hook_handle, forward_pre_hook_handle, forward_hook_handle])
+
+    if torch_version_above_or_equal_2:
+        backward_pre_hook_handle = module.register_full_backward_pre_hook(
+            pdg.service.module_processor.node_hook(prefix + Const.BACKWARD, Const.START))
+        pdg.service.check_register_full_backward_hook(module)
+        full_backward_hook_handle = module.register_full_backward_hook(
+            pdg.service.module_processor.node_hook(prefix + Const.BACKWARD, Const.STOP))
+        hook_handle_list.extend([backward_pre_hook_handle, full_backward_hook_handle])
+
+
+def remove_hook():
+    for hook_handle in hook_handle_list:
+        if isinstance(hook_handle, torch.utils.hooks.RemovableHandle):
+            hook_handle.remove()
diff --git a/debug/accuracy_tools/msprobe/pytorch/service.py b/debug/accuracy_tools/msprobe/pytorch/service.py
index 1478b2c84996223f933b3598ff315fa34dab3cd1..01502067232e57e696c049b84caf77c3ce963ea2 100644
--- a/debug/accuracy_tools/msprobe/pytorch/service.py
+++ b/debug/accuracy_tools/msprobe/pytorch/service.py
@@ -51,14 +51,13 @@ class Service:
                 module._is_full_backward_hook is False:
             return True
         return False
-    
-    def check_register_full_backward_hook(self, module, backward_hook):
+
+    def check_register_full_backward_hook(self, module):
         if self.is_registered_backward_hook(module):
             module._backward_hooks.clear()
             module._is_full_backward_hook = None
-            logger.warning("Found regular backward hooks. Removing them and switching to full backward hooks.")
-            module.register_full_backward_hook(backward_hook)
-    
+            logger.warning("Found deprecated backward hooks. Removing them and switching to full backward hooks.")
+
     def build_hook(self, module_type, name):
         def pre_hook(api_or_module_name, module, args, kwargs):
             if not self.should_execute_hook():
@@ -239,10 +238,12 @@ class Service:
                 if torch_version_above_or_equal_2:
                     module.register_forward_hook(forward_hook, with_kwargs=True)
                 else:
-                    self.check_register_full_backward_hook(module,
-                        self.module_processor.node_hook(prefix + Const.BACKWARD, Const.STOP))
+                    self.check_register_full_backward_hook(module)
+                    module.register_full_backward_hook(
+                        self.module_processor.node_hook(prefix + Const.BACKWARD, Const.STOP))
                     module.register_forward_hook(forward_hook_torch_version_below_2)
-                self.check_register_full_backward_hook(module, backward_hook)
+                self.check_register_full_backward_hook(module)
+                module.register_full_backward_hook(backward_hook)
 
                 module.register_forward_pre_hook(
                     self.module_processor.node_hook(prefix + Const.FORWARD, Const.START))
@@ -251,7 +252,9 @@ class Service:
                 if torch_version_above_or_equal_2:
                     module.register_full_backward_pre_hook(
                         self.module_processor.node_hook(prefix + Const.BACKWARD, Const.START))
-                    self.check_register_full_backward_hook(module, self.module_processor.node_hook(prefix + Const.BACKWARD, Const.STOP))
+                    self.check_register_full_backward_hook(module)
+                    module.register_full_backward_hook(
+                        self.module_processor.node_hook(prefix + Const.BACKWARD, Const.STOP))
 
         if self.config.level in ["mix", "L1", "L2"]:
             api_register.initialize_hook(functools.partial(self.build_hook, BaseScope.Module_Type_API),
diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/functional/test_dump_module.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/functional/test_dump_module.py
deleted file mode 100644
index d67adf2f91292391ff01d450bb5647524f6fc9c4..0000000000000000000000000000000000000000
--- a/debug/accuracy_tools/msprobe/test/pytorch_ut/functional/test_dump_module.py
+++ /dev/null
@@ -1,15 +0,0 @@
-import unittest
-
-import torch.nn as nn
-from msprobe.pytorch import PrecisionDebugger
-from msprobe.pytorch.functional.dump_module import module_dump, module_count
-
-
-class TestDumpModule(unittest.TestCase):
-    def setUp(self):
-        self.module = nn.Linear(in_features=8, out_features=4)
-
-    def test_module_dump(self):
-        PrecisionDebugger(dump_path="./dump")
-        module_dump(self.module, "TestModule")
-        self.assertTrue("TestModule" in module_count)
diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/functional/test_module_dump.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/functional/test_module_dump.py
new file mode 100644
index 0000000000000000000000000000000000000000..a8e685e39c13f080b05552493135246c48c1dd94
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/functional/test_module_dump.py
@@ -0,0 +1,72 @@
+# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+from unittest.mock import patch, MagicMock
+
+import torch
+import torch.nn as nn
+from msprobe.core.common.exceptions import MsprobeException
+from msprobe.core.common.log import logger
+from msprobe.pytorch import PrecisionDebugger
+from msprobe.pytorch.service import torch_version_above_or_equal_2
+from msprobe.pytorch.functional.module_dump import module_dump, module_dump_end, \
+    hook_handle_list, remove_hook, register_hook
+
+
+class TestModuleDump(unittest.TestCase):
+    def setUp(self):
+        self.module = nn.Linear(8, 4)
+
+    def tearDown(self):
+        hook_handle_list.clear()
+
+    # @patch.object(logger, 'error')
+    # def test_module_dump(self, mock_error):
+    #     with self.assertRaises(MsprobeException) as context:
+    #         module_dump(1, "TestModule")
+    #     self.assertEqual(context.exception.code, MsprobeException.INVALID_PARAM_ERROR)
+    #     mock_error.assert_called_with("The parameter module in module_dump must be a Module subclass.")
+    #
+    #     with self.assertRaises(MsprobeException) as context:
+    #         module_dump(self.module, 1)
+    #     self.assertEqual(context.exception.code, MsprobeException.INVALID_PARAM_ERROR)
+    #     mock_error.assert_called_with("The parameter dump_name in module_dump must be a str type.")
+    #
+    #     with patch('msprobe.pytorch.functional.module_dump.register_hook') as mock_register_hook:
+    #         module_dump(self.module, "TestModule")
+    #         mock_register_hook.assert_called_with(self.module, "TestModule")
+    #
+    # def test_module_dump_end(self):
+    #     hook_handle_list.extend([1, 2, 3])
+    #     with patch('msprobe.pytorch.functional.module_dump.remove_hook') as mock_remove_hook:
+    #         module_dump_end()
+    #         mock_remove_hook.assert_called_once()
+    #         self.assertEqual(hook_handle_list, [])
+
+    def test_register_hook(self):
+        PrecisionDebugger(dump_path="./")
+        register_hook(self.module, "TestModule")
+        if torch_version_above_or_equal_2:
+            self.assertEqual(len(hook_handle_list), 6)
+        else:
+            self.assertEqual(len(hook_handle_list), 5)
+
+    def test_remove_hook(self):
+        mock_handle = MagicMock(spec=torch.utils.hooks.RemovableHandle)
+        hook_handle_list.append(mock_handle)
+        remove_hook()
+
+        mock_handle.remove.assert_called_once()