From db7f723a531d3eeea7d5a18c485a0dc508eea54f Mon Sep 17 00:00:00 2001 From: curry3 <485078529@qq.com> Date: Tue, 17 Jun 2025 16:18:08 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90=E5=AE=89=E5=85=A8=E3=80=91=E4=B8=8D?= =?UTF-8?q?=E5=AF=B9=E5=A4=96=E6=8F=90=E4=BE=9B=E6=8E=A5=E5=8F=A3=E6=94=B9?= =?UTF-8?q?=E4=B8=BA=E4=B8=8B=E5=88=92=E7=BA=BF=E5=91=BD=E5=90=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../core/debugger/precision_debugger.py | 29 ++-- .../mindspore/debugger/precision_debugger.py | 6 +- .../pytorch/debugger/precision_debugger.py | 47 +++--- .../debugger/test_ms_precision_debugger.py | 6 +- .../test/mindspore_ut/test_ms_debug_save.py | 2 +- .../debugger/test_pt_debugger_start.py | 2 +- .../debugger/test_pt_precision_debugger.py | 148 ++++++++++-------- .../test/pytorch_ut/test_pt_debug_save.py | 2 +- 8 files changed, 134 insertions(+), 108 deletions(-) diff --git a/debug/accuracy_tools/msprobe/core/debugger/precision_debugger.py b/debug/accuracy_tools/msprobe/core/debugger/precision_debugger.py index 39016ea5d..03698530b 100644 --- a/debug/accuracy_tools/msprobe/core/debugger/precision_debugger.py +++ b/debug/accuracy_tools/msprobe/core/debugger/precision_debugger.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import os from msprobe.core.common.const import Const, FileCheckConst, MsgConst @@ -46,14 +47,14 @@ class BasePrecisionDebugger: if self.initialized: return self.initialized = True - self.check_input_params(config_path, task, dump_path, level) - self.common_config, self.task_config = self.parse_config_path(config_path, task) + self._check_input_params(config_path, task, dump_path, level) + self.common_config, self.task_config = self._parse_config_path(config_path, task) self.task = self.common_config.task if step is not None: self.common_config.step = get_real_step_or_rank(step, Const.STEP) @staticmethod - def check_input_params(config_path, task, dump_path, level): + def _check_input_params(config_path, task, dump_path, level): if not config_path: config_path = os.path.join(os.path.dirname(__file__), "../../config.json") @@ -80,16 +81,7 @@ class BasePrecisionDebugger: @staticmethod def _get_task_config(task, json_config): - raise NotImplementedError("Subclass must implment _get_task_config") - - @classmethod - def get_instance(cls): - instance = cls._instance - if not instance: - raise Exception(MsgConst.NOT_CREATED_INSTANCE) - if instance.task in BasePrecisionDebugger.tasks_not_need_debugger: - instance = None - return instance + raise NotImplementedError("Subclass must implement _get_task_config") @classmethod def forward_backward_dump_end(cls): @@ -130,7 +122,16 @@ class BasePrecisionDebugger: raise Exception(MsgConst.NOT_CREATED_INSTANCE) instance.service.restore_custom_api(module, api) - def parse_config_path(self, json_file_path, task): + @classmethod + def _get_instance(cls): + instance = cls._instance + if not instance: + raise Exception(MsgConst.NOT_CREATED_INSTANCE) + if instance.task in BasePrecisionDebugger.tasks_not_need_debugger: + instance = None + return instance + + def _parse_config_path(self, json_file_path, task): if not json_file_path: json_file_path = os.path.join(os.path.dirname(__file__), "../../config.json") diff --git a/debug/accuracy_tools/msprobe/mindspore/debugger/precision_debugger.py b/debug/accuracy_tools/msprobe/mindspore/debugger/precision_debugger.py index d122b6109..efc9b39b6 100644 --- a/debug/accuracy_tools/msprobe/mindspore/debugger/precision_debugger.py +++ b/debug/accuracy_tools/msprobe/mindspore/debugger/precision_debugger.py @@ -129,7 +129,7 @@ class PrecisionDebugger(BasePrecisionDebugger): @classmethod def start(cls, model=None, token_range=None): - instance = cls.get_instance() + instance = cls._get_instance() if instance is None: return if cls._need_msprobe_c() and _msprobe_c: @@ -158,7 +158,7 @@ class PrecisionDebugger(BasePrecisionDebugger): @classmethod def stop(cls): - instance = cls.get_instance() + instance = cls._get_instance() if instance is None: return @@ -174,7 +174,7 @@ class PrecisionDebugger(BasePrecisionDebugger): @classmethod def step(cls): - instance = cls.get_instance() + instance = cls._get_instance() if instance is None: return diff --git a/debug/accuracy_tools/msprobe/pytorch/debugger/precision_debugger.py b/debug/accuracy_tools/msprobe/pytorch/debugger/precision_debugger.py index fd39a9a18..0aa3f0c96 100644 --- a/debug/accuracy_tools/msprobe/pytorch/debugger/precision_debugger.py +++ b/debug/accuracy_tools/msprobe/pytorch/debugger/precision_debugger.py @@ -55,17 +55,34 @@ class PrecisionDebugger(BasePrecisionDebugger): self.enable_dataloader = self.config.enable_dataloader self._param_warning() - @property - def instance(self): - return self._instance - @staticmethod def _get_task_config(task, json_config): return parse_task_config(task, json_config) + @staticmethod + def _iter_tracer(func): + def func_wrapper(*args, **kwargs): + debugger_instance = PrecisionDebugger._instance + if not debugger_instance: + raise MsprobeException( + MsprobeException.INTERFACE_USAGE_ERROR, + f"PrecisionDebugger must be instantiated before executing the dataloader iteration" + ) + + debugger_instance.enable_dataloader = False + if not debugger_instance.service.first_start: + debugger_instance.stop() + debugger_instance.step() + result = func(*args, **kwargs) + debugger_instance.start() + debugger_instance.enable_dataloader = True + return result + + return func_wrapper + @classmethod def start(cls, model=None, token_range=None): - instance = cls.get_instance() + instance = cls._get_instance() if instance is None: return @@ -79,7 +96,7 @@ class PrecisionDebugger(BasePrecisionDebugger): @classmethod def stop(cls): - instance = cls.get_instance() + instance = cls._get_instance() if instance is None: return if instance.enable_dataloader: @@ -89,7 +106,7 @@ class PrecisionDebugger(BasePrecisionDebugger): @classmethod def step(cls): - instance = cls.get_instance() + instance = cls._get_instance() if instance is None: return cls._instance.service.step() @@ -123,7 +140,7 @@ class PrecisionDebugger(BasePrecisionDebugger): ) if self.enable_dataloader: logger.warning_on_rank_0("The enable_dataloader feature will be deprecated in the future.") - dataloader._BaseDataLoaderIter.__next__ = iter_tracer(dataloader._BaseDataLoaderIter.__next__) + dataloader._BaseDataLoaderIter.__next__ = self._iter_tracer(dataloader._BaseDataLoaderIter.__next__) def module_dump(module, dump_name): @@ -155,17 +172,3 @@ def module_dump_end(): f"PrecisionDebugger must be instantiated before using module_dump_end interface" ) instance.module_dumper.stop_module_dump() - - -def iter_tracer(func): - def func_wrapper(*args, **kwargs): - debugger_instance = PrecisionDebugger.instance - debugger_instance.enable_dataloader = False - if not debugger_instance.service.first_start: - debugger_instance.stop() - debugger_instance.step() - result = func(*args, **kwargs) - debugger_instance.start() - debugger_instance.enable_dataloader = True - return result - return func_wrapper diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/debugger/test_ms_precision_debugger.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/debugger/test_ms_precision_debugger.py index f4be0fc63..91c712d6c 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/debugger/test_ms_precision_debugger.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/debugger/test_ms_precision_debugger.py @@ -55,7 +55,7 @@ class TestPrecisionDebugger(unittest.TestCase): mock_get_mode = MagicMock() mock_parse_json_config = MagicMock() - with patch.object(BasePrecisionDebugger, "parse_config_path", new=mock_parse_json_config), \ + with patch.object(BasePrecisionDebugger, "_parse_config_path", new=mock_parse_json_config), \ patch.object(PrecisionDebugger, "_get_execution_mode", new=mock_get_mode), \ patch("msprobe.mindspore.debugger.precision_debugger.TaskHandlerFactory.create", return_value=handler), \ patch("msprobe.mindspore.debugger.precision_debugger.set_register_backward_hook_functions"): @@ -83,7 +83,7 @@ class TestPrecisionDebugger(unittest.TestCase): debugger.start() self.assertEqual(str(context.exception), MsgConst.NOT_CREATED_INSTANCE) - with patch.object(BasePrecisionDebugger, "parse_config_path", new=mock_parse_json_config), \ + with patch.object(BasePrecisionDebugger, "_parse_config_path", new=mock_parse_json_config), \ patch.object(PrecisionDebugger, "_get_execution_mode", new=mock_get_mode), \ patch("msprobe.mindspore.debugger.precision_debugger.TaskHandlerFactory.create", return_value=handler), \ patch("msprobe.mindspore.debugger.precision_debugger.set_register_backward_hook_functions"): @@ -139,7 +139,7 @@ class TestPrecisionDebugger(unittest.TestCase): common_config = CommonConfig(json_config) task_config = StatisticsConfig(json_config) - with patch.object(BasePrecisionDebugger, "parse_config_path", return_value=(common_config, task_config)), \ + with patch.object(BasePrecisionDebugger, "_parse_config_path", return_value=(common_config, task_config)), \ patch("msprobe.mindspore.debugger.precision_debugger.set_register_backward_hook_functions"): debugger = PrecisionDebugger() debugger.task = "statistics" diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_debug_save.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_debug_save.py index ca79a50d3..2ae1d9250 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_debug_save.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_debug_save.py @@ -39,7 +39,7 @@ class TestMindsporeDebuggerSave(TestCase): } common_config = CommonConfig(statistics_task_json) task_config = StatisticsConfig(statistics_task_json) - with patch.object(BasePrecisionDebugger, "parse_config_path", return_value=(common_config, task_config)), \ + with patch.object(BasePrecisionDebugger, "_parse_config_path", return_value=(common_config, task_config)), \ patch("msprobe.mindspore.debugger.precision_debugger.set_register_backward_hook_functions"): self.debugger = PrecisionDebugger() diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/debugger/test_pt_debugger_start.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/debugger/test_pt_debugger_start.py index 270a2dcd3..9b6622831 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/debugger/test_pt_debugger_start.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/debugger/test_pt_debugger_start.py @@ -50,7 +50,7 @@ class MultiStartDebugger: common_config = CommonConfig(json_config) task_config = StatisticsConfig(json_config) - with patch.object(BasePrecisionDebugger, "parse_config_path", return_value=(common_config, task_config)): + with patch.object(BasePrecisionDebugger, "_parse_config_path", return_value=(common_config, task_config)): cls.debugger = PrecisionDebugger(task="statistics", level="L0", dump_path=dump_path) @classmethod diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/debugger/test_pt_precision_debugger.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/debugger/test_pt_precision_debugger.py index 7228c3990..249432717 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/debugger/test_pt_precision_debugger.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/debugger/test_pt_precision_debugger.py @@ -8,7 +8,7 @@ import torch from msprobe.core.common.const import Const, MsgConst from msprobe.core.common.utils import get_real_step_or_rank from msprobe.core.common.exceptions import MsprobeException, FileCheckException -from msprobe.pytorch.debugger.precision_debugger import PrecisionDebugger, iter_tracer +from msprobe.pytorch.debugger.precision_debugger import PrecisionDebugger from msprobe.pytorch.grad_probe.grad_monitor import GradientMonitor from msprobe.test.pytorch_ut.grad_probe.test_grad_monitor import common_config, task_config from msprobe.core.common_config import CommonConfig @@ -56,43 +56,43 @@ class TestPrecisionDebugger(unittest.TestCase): step = get_real_step_or_rank([0, 1, "3-5"], Const.STEP) self.assertListEqual(step, [0, 1, 3, 4, 5]) - def test_instance(self): - debugger1 = PrecisionDebugger(dump_path="./dump_path") - debugger2 = PrecisionDebugger(dump_path="./dump_path") - self.assertIs(debugger1.instance, debugger2.instance) - def test_check_input_params(self): - args = Args(config_path = 1) + args = Args(config_path=1) with self.assertRaises(MsprobeException) as context: - PrecisionDebugger.check_input_params(args.config_path, args.task, args.dump_path, args.level) + PrecisionDebugger._check_input_params(args.config_path, args.task, args.dump_path, args.level) self.assertEqual(context.exception.code, MsprobeException.INVALID_PARAM_ERROR) - args = Args(config_path = "./") + args = Args(config_path="./") with self.assertRaises(FileCheckException) as context: - PrecisionDebugger.check_input_params(args.config_path, args.task, args.dump_path, args.level) + PrecisionDebugger._check_input_params(args.config_path, args.task, args.dump_path, args.level) self.assertEqual(context.exception.code, FileCheckException.INVALID_FILE_ERROR) - args = Args(task = 1) + args = Args(task=1) with self.assertRaises(MsprobeException) as context: - PrecisionDebugger.check_input_params(args.config_path, args.task, args.dump_path, args.level) + PrecisionDebugger._check_input_params(args.config_path, args.task, args.dump_path, args.level) self.assertEqual(context.exception.code, MsprobeException.INVALID_PARAM_ERROR) - args = Args(dump_path = 1) + args = Args(dump_path=1) with self.assertRaises(MsprobeException) as context: - PrecisionDebugger.check_input_params(args.config_path, args.task, args.dump_path, args.level) + PrecisionDebugger._check_input_params(args.config_path, args.task, args.dump_path, args.level) self.assertEqual(context.exception.code, MsprobeException.INVALID_PARAM_ERROR) - args = Args(level = 1) + args = Args(level=1) with self.assertRaises(MsprobeException) as context: - PrecisionDebugger.check_input_params(args.config_path, args.task, args.dump_path, args.level) + PrecisionDebugger._check_input_params(args.config_path, args.task, args.dump_path, args.level) self.assertEqual(context.exception.code, MsprobeException.INVALID_PARAM_ERROR) - args = Args(config_path = os.path.join(os.path.dirname(__file__), "../../../config.json"), - task = Const.TASK_LIST[0], - dump_path="./dump_path", - level = Const.LEVEL_LIST[0], - model = torch.nn.Module()) - checked_input_params = PrecisionDebugger.check_input_params(args.config_path, args.task, args.dump_path, args.level) + args = Args(config_path=os.path.join(os.path.dirname(__file__), "../../../config.json"), + task=Const.TASK_LIST[0], + dump_path="./dump_path", + level=Const.LEVEL_LIST[0], + model=torch.nn.Module()) + checked_input_params = PrecisionDebugger._check_input_params( + args.config_path, + args.task, + args.dump_path, + args.level + ) self.assertIsNone(checked_input_params) def test_start_grad_probe(self): @@ -101,14 +101,14 @@ class TestPrecisionDebugger(unittest.TestCase): PrecisionDebugger.start() self.assertEqual(str(context.exception), MsgConst.NOT_CREATED_INSTANCE) - with patch.object(BasePrecisionDebugger, "parse_config_path", + with patch.object(BasePrecisionDebugger, "_parse_config_path", return_value=(self.grad_common_config, self.grad_task_config)): PrecisionDebugger._instance = PrecisionDebugger(task=Const.GRAD_PROBE, dump_path="./dump_path") checked_start = PrecisionDebugger.start() self.assertIsNone(checked_start) def test_start_statistics(self): - with patch.object(BasePrecisionDebugger, "parse_config_path", + with patch.object(BasePrecisionDebugger, "_parse_config_path", return_value=(self.statistics_common_config, self.statistics_task_config)): debugger = PrecisionDebugger(dump_path="./dump_path") debugger.service = MagicMock() @@ -118,8 +118,11 @@ class TestPrecisionDebugger(unittest.TestCase): debugger.service.start.assert_called_once() def test_forward_backward_dump_end(self): - with patch.object(BasePrecisionDebugger, "parse_config_path", return_value=(self.statistics_common_config, - self.statistics_task_config)): + with patch.object( + BasePrecisionDebugger, + "_parse_config_path", + return_value=(self.statistics_common_config,self.statistics_task_config) + ): debugger = PrecisionDebugger(dump_path="./dump_path", task='statistics') debugger.service = MagicMock() debugger.config = MagicMock() @@ -133,7 +136,7 @@ class TestPrecisionDebugger(unittest.TestCase): PrecisionDebugger.stop() self.assertEqual(str(context.exception), MsgConst.NOT_CREATED_INSTANCE) - with patch.object(BasePrecisionDebugger, "parse_config_path", + with patch.object(BasePrecisionDebugger, "_parse_config_path", return_value=(self.grad_common_config, self.grad_task_config)): PrecisionDebugger._instance = PrecisionDebugger(task=Const.GRAD_PROBE, dump_path="./dump_path") checked_stop = PrecisionDebugger.stop() @@ -151,9 +154,8 @@ class TestPrecisionDebugger(unittest.TestCase): PrecisionDebugger._instance = None PrecisionDebugger.step() self.assertEqual(str(context.exception), MsgConst.NOT_CREATED_INSTANCE) - with patch.object(BasePrecisionDebugger, "parse_config_path", + with patch.object(BasePrecisionDebugger, "_parse_config_path", return_value=(self.grad_common_config, self.grad_task_config)): - PrecisionDebugger._instance = PrecisionDebugger(task=Const.GRAD_PROBE, dump_path="./dump_path") checked_step = PrecisionDebugger.step() self.assertIsNone(checked_step) @@ -171,8 +173,11 @@ class TestPrecisionDebugger(unittest.TestCase): PrecisionDebugger.monitor(torch.nn.Module()) self.assertEqual(str(context.exception), MsgConst.NOT_CREATED_INSTANCE) - with patch.object(BasePrecisionDebugger, "parse_config_path", return_value=(self.statistics_common_config, - self.statistics_task_config)): + with patch.object( + BasePrecisionDebugger, + "_parse_config_path", + return_value=(self.statistics_common_config, self.statistics_task_config) + ): debugger = PrecisionDebugger(task=Const.STATISTICS, dump_path="./dump_path") checked_monitor = debugger.monitor(torch.nn.Module()) self.assertIsNone(checked_monitor) @@ -184,40 +189,57 @@ class TestPrecisionDebugger(unittest.TestCase): debugger.gm.monitor(torch.nn.Module()) debugger.gm.monitor.assert_called_once() - @patch('msprobe.pytorch.debugger.precision_debugger.PrecisionDebugger') - def test_iter_tracer(self, mock_debugger): - mock_debugger_instance = mock_debugger.instance = MagicMock() - mock_debugger_instance.service.first_start = False - - @iter_tracer - def dataloader_func(): - return "test_iter_tracer" - result = dataloader_func() - self.assertEqual(result, "test_iter_tracer") - - mock_debugger_instance.stop.assert_called_once() - mock_debugger_instance.step.assert_called_once() - mock_debugger_instance.start.assert_called_once() - self.assertTrue(mock_debugger_instance.enable_dataloader) - - @patch('msprobe.pytorch.debugger.precision_debugger.PrecisionDebugger') - def test_iter_tracer_first_start(self, mock_debugger): - mock_debugger_instance = mock_debugger.instance = MagicMock() - mock_debugger_instance.service.first_start = True - - @iter_tracer - def dataloader_func(): - return "test_iter_tracer" - result = dataloader_func() - self.assertEqual(result, "test_iter_tracer") - - mock_debugger_instance.stop.assert_not_called() - mock_debugger_instance.step.assert_not_called() - mock_debugger_instance.start.assert_called_once() - self.assertTrue(mock_debugger_instance.enable_dataloader) - def tearDown(self): if os.path.exists("./dump_path/"): shutil.rmtree("./dump_path/") if os.path.exists("./grad_output/"): shutil.rmtree("./grad_output/") + + +class TestIterTracer(unittest.TestCase): + def setUp(self): + self.debugger = MagicMock() + self.debugger.service.first_start = False + self.debugger.enable_dataloader = True + self.ori_instance = PrecisionDebugger._instance + PrecisionDebugger._instance = self.debugger + + def tearDown(self): + PrecisionDebugger._instance = self.ori_instance + + def test_debugger_with_not_first_start(self): + @PrecisionDebugger._iter_tracer + def test_func(): + return "test case 1" + + result = test_func() + + self.assertEqual(result, "test case 1") + self.debugger.stop.assert_called_once() + self.debugger.step.assert_called_once() + self.debugger.start.assert_called_once() + + def test_debugger_with_first_start(self): + self.debugger.service.first_start = True + + @PrecisionDebugger._iter_tracer + def test_func(): + return "test case 2" + + result = test_func() + self.assertEqual(result, "test case 2") + self.debugger.stop.assert_not_called() + self.debugger.step.assert_not_called() + self.debugger.start.assert_called_once() + + def test_no_debugger_instance(self): + PrecisionDebugger._instance = None + + @PrecisionDebugger._iter_tracer + def test_func(): + return "test case 3" + + with self.assertRaises(MsprobeException) as context: + result = test_func() + self.assertEqual(result, "test case 3") + self.assertEqual(context.exception.code, MsprobeException.INTERFACE_USAGE_ERROR) diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/test_pt_debug_save.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/test_pt_debug_save.py index 52b6aa623..e517e1cef 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/test_pt_debug_save.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/test_pt_debug_save.py @@ -37,7 +37,7 @@ class TestPytorchDebuggerSave(TestCase): } common_config = CommonConfig(statistics_task_json) task_config = BaseConfig(statistics_task_json) - with patch.object(BasePrecisionDebugger, "parse_config_path", return_value=(common_config, task_config)): + with patch.object(BasePrecisionDebugger, "_parse_config_path", return_value=(common_config, task_config)): self.debugger = PrecisionDebugger() def test_forward_and_backward(self): -- Gitee