diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py index b0ebe5da9d59dc06beb1be89cccd7bf665102ed7..f595aeddabc6a301c8d2f62dcb1173645d5c69e2 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py @@ -110,6 +110,16 @@ class BaseDataProcessor: stack_info_struct = {name: stack_str} return stack_info_struct + @staticmethod + def transfer_type(data): + dtype = str(type(data)) + if 'int' in dtype: + return int(data) + elif 'float' in dtype: + return float(data) + else: + return data + @staticmethod def _convert_numpy_to_builtin(arg): type_mapping = { diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py index 2abb294f60adc35f8844c8904944e25b09046b73..45e09a4513fd358dd56e6bdf9672c49425a022e1 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py @@ -116,10 +116,10 @@ class MindsporeDataProcessor(BaseDataProcessor): 'type': 'mindspore.Tensor', 'dtype': str(tensor.dtype), 'shape': tensor.shape, - 'Max': tensor_stat.max, - 'Min': tensor_stat.min, - 'Mean': tensor_stat.mean, - 'Norm': tensor_stat.norm + 'Max': self.transfer_type(tensor_stat.max), + 'Min': self.transfer_type(tensor_stat.min), + 'Mean': self.transfer_type(tensor_stat.mean), + 'Norm': self.transfer_type(tensor_stat.norm), } if self.config.summary_mode == Const.MD5: tensor_md5 = self.get_md5_for_tensor(tensor) diff --git a/debug/accuracy_tools/msprobe/mindspore/debugger/debugger_config.py b/debug/accuracy_tools/msprobe/mindspore/debugger/debugger_config.py index 78dc253fa87e9d4525d0bec4438f1f511e39575a..2858f8f01be5c3dd66fa8d4319871b610bb89bbd 100644 --- a/debug/accuracy_tools/msprobe/mindspore/debugger/debugger_config.py +++ b/debug/accuracy_tools/msprobe/mindspore/debugger/debugger_config.py @@ -64,8 +64,8 @@ class DebuggerConfig: def _check_step(self): for s in self.step: - if not isinstance(s, int): - raise ValueError(f"step element {s} should be int") + if not isinstance(s, int) or s < 0: + raise ValueError(f"step element {s} should be a positive integer.") def _make_dump_path_if_not_exists(self): check_path_before_create(self.dump_path) diff --git a/debug/accuracy_tools/msprobe/mindspore/doc/dump.md b/debug/accuracy_tools/msprobe/mindspore/doc/dump.md index 0d45e1b149d409c92b732288cbc7be34235562fc..c9b8c38f2972812da1d738279139a7b640ee767e 100644 --- a/debug/accuracy_tools/msprobe/mindspore/doc/dump.md +++ b/debug/accuracy_tools/msprobe/mindspore/doc/dump.md @@ -101,6 +101,8 @@ debugger.start() ### MindSpore动态图场景 +当使用模型使用for循环时,在每个迭代的开始插入debugger.start(),在每个迭代的结束插入debugger.stop()与debugger.step(): + ```Python import mindspore as ms from msprobe.mindspore import PrecisionDebugger @@ -122,6 +124,21 @@ for data, label in data_loader: debugger.step() # 结束一个step的dump ``` +当使用模型的train方法而非for循环时,可以通过在callbacks参数中传入MsprobeStep(debugger): + +```Python +from msprobe.mindspore.common.utils import MsprobeStep +from msprobe.mindspore import PrecisionDebugger + +# 初始化PrecisionDebugger +debugger = PrecisionDebugger(config_path="./config.json") + +# 自动在每个step开始时调用start(),在每个step结束时调用stop()和step()。 +# 这意味着您无需手动在循环内添加start、stop和step函数,框架会自动完成数据的dump操作。 +trainer.train(1, dataset_train, callbacks=[loss_monior, MsprobeStep(debugger)]) + +``` + ## dump结果文件介绍 ### MindSpore静态图场景 diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml index 089f444b6181f0623c8029926c4808ab22ae27ca..ca2a578f9b3230242aace90e3429c597d3daa671 100644 --- a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +++ b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml @@ -347,7 +347,6 @@ ops: - linspace - logspace - one_hot - - arange - range - heaviside - bernoulli @@ -726,7 +725,6 @@ tensor: - trace - swapaxes - tile - - to - topk - tril - tensor_split @@ -760,7 +758,6 @@ mint.ops: - all - any - any_ex - - arange - argmax - avg_pool2d - baddbmm diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_primitive_dump.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_primitive_dump.py index b85fbacd482e38775e67fd9cea89e334e74ad433..aa6372fa385c410d8395c01fe53f859e6d4749e7 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_primitive_dump.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_primitive_dump.py @@ -23,6 +23,8 @@ from msprobe.mindspore.service import Service from msprobe.core.common.exceptions import MsprobeException from msprobe.core.common_config import CommonConfig, BaseConfig from msprobe.mindspore.debugger.debugger_config import DebuggerConfig +from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell +from collections import defaultdict class DummyModel(nn.Cell): @@ -76,3 +78,23 @@ class TestService(unittest.TestCase): self.assertEqual(self.service.primitive_counters[primitive_name], 0) self.service.update_primitive_counters(primitive_name) self.assertEqual(self.service.primitive_counters[primitive_name], 1) + + def test_step_updates_iteration(self): + initial_iter = self.service.current_iter + self.service.step() + self.assertEqual(self.service.current_iter, initial_iter + 1) + + @patch.object(HOOKCell, 'cell_count', new_callable=lambda: defaultdict(int)) + def test_step_resets_counters(self, _): + # 假设在 step 调用之前已经有一些 primitive_counters + self.service.primitive_counters["test_primitive"] = 5 + self.service.step() + self.assertEqual(self.service.primitive_counters, {}) + self.assertEqual(HOOKCell.cell_count, defaultdict(int)) + + def test_step_calls_update_iter(self): + # 检查是否在调用 step 时调用了 update_iter + with patch.object(self.service.data_collector, 'update_iter') as mock_update_iter: + initial_iter = self.service.current_iter + self.service.step() + mock_update_iter.assert_called_once_with(initial_iter + 1) \ No newline at end of file