From 9c0bca408947b190204074602e20d8496e859aa1 Mon Sep 17 00:00:00 2001 From: yangxinxian <947098055@qq.com> Date: Tue, 20 Aug 2024 15:09:46 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E6=8C=87=E5=AE=9Astep?= =?UTF-8?q?=E9=9C=80=E6=B1=82=E8=B5=84=E6=96=99=EF=BC=8C=E9=97=AE=E9=A2=98?= =?UTF-8?q?=E5=8D=95=EF=BC=8CUT?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../core/data_dump/data_processor/base.py | 10 +++++++++ .../data_processor/mindspore_processor.py | 8 +++---- .../mindspore/debugger/debugger_config.py | 4 ++-- .../msprobe/mindspore/doc/dump.md | 17 ++++++++++++++ .../dump/hook_cell/support_wrap_ops.yaml | 3 --- .../test/mindspore_ut/test_primitive_dump.py | 22 +++++++++++++++++++ 6 files changed, 55 insertions(+), 9 deletions(-) diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py index b0ebe5da9d5..f595aeddabc 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py @@ -110,6 +110,16 @@ class BaseDataProcessor: stack_info_struct = {name: stack_str} return stack_info_struct + @staticmethod + def transfer_type(data): + dtype = str(type(data)) + if 'int' in dtype: + return int(data) + elif 'float' in dtype: + return float(data) + else: + return data + @staticmethod def _convert_numpy_to_builtin(arg): type_mapping = { diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py index 2abb294f60a..45e09a4513f 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py @@ -116,10 +116,10 @@ class MindsporeDataProcessor(BaseDataProcessor): 'type': 'mindspore.Tensor', 'dtype': str(tensor.dtype), 'shape': tensor.shape, - 'Max': tensor_stat.max, - 'Min': tensor_stat.min, - 'Mean': tensor_stat.mean, - 'Norm': tensor_stat.norm + 'Max': self.transfer_type(tensor_stat.max), + 'Min': self.transfer_type(tensor_stat.min), + 'Mean': self.transfer_type(tensor_stat.mean), + 'Norm': self.transfer_type(tensor_stat.norm), } if self.config.summary_mode == Const.MD5: tensor_md5 = self.get_md5_for_tensor(tensor) diff --git a/debug/accuracy_tools/msprobe/mindspore/debugger/debugger_config.py b/debug/accuracy_tools/msprobe/mindspore/debugger/debugger_config.py index 78dc253fa87..2858f8f01be 100644 --- a/debug/accuracy_tools/msprobe/mindspore/debugger/debugger_config.py +++ b/debug/accuracy_tools/msprobe/mindspore/debugger/debugger_config.py @@ -64,8 +64,8 @@ class DebuggerConfig: def _check_step(self): for s in self.step: - if not isinstance(s, int): - raise ValueError(f"step element {s} should be int") + if not isinstance(s, int) or s < 0: + raise ValueError(f"step element {s} should be a positive integer.") def _make_dump_path_if_not_exists(self): check_path_before_create(self.dump_path) diff --git a/debug/accuracy_tools/msprobe/mindspore/doc/dump.md b/debug/accuracy_tools/msprobe/mindspore/doc/dump.md index 0d45e1b149d..c9b8c38f297 100644 --- a/debug/accuracy_tools/msprobe/mindspore/doc/dump.md +++ b/debug/accuracy_tools/msprobe/mindspore/doc/dump.md @@ -101,6 +101,8 @@ debugger.start() ### MindSpore动态图场景 +当使用模型使用for循环时,在每个迭代的开始插入debugger.start(),在每个迭代的结束插入debugger.stop()与debugger.step(): + ```Python import mindspore as ms from msprobe.mindspore import PrecisionDebugger @@ -122,6 +124,21 @@ for data, label in data_loader: debugger.step() # 结束一个step的dump ``` +当使用模型的train方法而非for循环时,可以通过在callbacks参数中传入MsprobeStep(debugger): + +```Python +from msprobe.mindspore.common.utils import MsprobeStep +from msprobe.mindspore import PrecisionDebugger + +# 初始化PrecisionDebugger +debugger = PrecisionDebugger(config_path="./config.json") + +# 自动在每个step开始时调用start(),在每个step结束时调用stop()和step()。 +# 这意味着您无需手动在循环内添加start、stop和step函数,框架会自动完成数据的dump操作。 +trainer.train(1, dataset_train, callbacks=[loss_monior, MsprobeStep(debugger)]) + +``` + ## dump结果文件介绍 ### MindSpore静态图场景 diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml index 089f444b618..ca2a578f9b3 100644 --- a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +++ b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml @@ -347,7 +347,6 @@ ops: - linspace - logspace - one_hot - - arange - range - heaviside - bernoulli @@ -726,7 +725,6 @@ tensor: - trace - swapaxes - tile - - to - topk - tril - tensor_split @@ -760,7 +758,6 @@ mint.ops: - all - any - any_ex - - arange - argmax - avg_pool2d - baddbmm diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_primitive_dump.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_primitive_dump.py index b85fbacd482..aa6372fa385 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_primitive_dump.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_primitive_dump.py @@ -23,6 +23,8 @@ from msprobe.mindspore.service import Service from msprobe.core.common.exceptions import MsprobeException from msprobe.core.common_config import CommonConfig, BaseConfig from msprobe.mindspore.debugger.debugger_config import DebuggerConfig +from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell +from collections import defaultdict class DummyModel(nn.Cell): @@ -76,3 +78,23 @@ class TestService(unittest.TestCase): self.assertEqual(self.service.primitive_counters[primitive_name], 0) self.service.update_primitive_counters(primitive_name) self.assertEqual(self.service.primitive_counters[primitive_name], 1) + + def test_step_updates_iteration(self): + initial_iter = self.service.current_iter + self.service.step() + self.assertEqual(self.service.current_iter, initial_iter + 1) + + @patch.object(HOOKCell, 'cell_count', new_callable=lambda: defaultdict(int)) + def test_step_resets_counters(self, _): + # 假设在 step 调用之前已经有一些 primitive_counters + self.service.primitive_counters["test_primitive"] = 5 + self.service.step() + self.assertEqual(self.service.primitive_counters, {}) + self.assertEqual(HOOKCell.cell_count, defaultdict(int)) + + def test_step_calls_update_iter(self): + # 检查是否在调用 step 时调用了 update_iter + with patch.object(self.service.data_collector, 'update_iter') as mock_update_iter: + initial_iter = self.service.current_iter + self.service.step() + mock_update_iter.assert_called_once_with(initial_iter + 1) \ No newline at end of file -- Gitee