diff --git a/debug/accuracy_tools/msprobe/core/common/utils.py b/debug/accuracy_tools/msprobe/core/common/utils.py index 8d55dd4c1647e3584855491fa67ba3b9741ffa23..717f04b9962b147f7a4ac2569bf6f2c47f014dff 100644 --- a/debug/accuracy_tools/msprobe/core/common/utils.py +++ b/debug/accuracy_tools/msprobe/core/common/utils.py @@ -19,6 +19,7 @@ import re import subprocess import time from datetime import datetime, timezone +from dataclasses import is_dataclass import numpy as np @@ -491,4 +492,22 @@ class DumpPathAggregation: construct_file_path = None dump_tensor_data_dir = None free_benchmark_file_path = None - debug_file_path = None \ No newline at end of file + debug_file_path = None + + +def is_save_variable_valid(variable, valid_special_types, depth=0): + if depth > Const.DUMP_MAX_DEPTH: + return False + if isinstance(variable, valid_special_types): + return True + elif is_dataclass(variable): + return True + elif isinstance(variable, (list, tuple)): + return all(is_save_variable_valid(item, valid_special_types, depth + 1) for item in variable) + elif isinstance(variable, dict): + return all(isinstance(key, str) and is_save_variable_valid(value, valid_special_types, depth + 1) + for key, value in variable.items()) + elif variable is None: + return True + else: + return False \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/docs/05.data_dump_PyTorch.md b/debug/accuracy_tools/msprobe/docs/05.data_dump_PyTorch.md index 21cd87c0146047a9d40f384e9c2e7b46ce39fe71..b82c3c20a6a9003f4ac3db0441c43a20e07aaff2 100644 --- a/debug/accuracy_tools/msprobe/docs/05.data_dump_PyTorch.md +++ b/debug/accuracy_tools/msprobe/docs/05.data_dump_PyTorch.md @@ -185,7 +185,7 @@ save(variable, name, save_backward=True) **参数说明**: | 参数名称 | 参数含义 | 支持数据类型 | 是否必选| | ---------- | ------------------| ------------------- | ------------------- | -| variable | 需要保存的变量 |dict, list, tuple, torch.tensor, int, float, str | 是 | +| variable | 需要保存的变量 |dict, list, tuple, np.integer, np.floating, np.bool_, np.complexfloating, np.str_, np.byte, np.unicode_, np.ndarray, bool, int, float, str, slice, type(Ellipsis), torch.device, torch.dtype, torch.Size, torch.Tensor, torch.memory_format, torch.distributed.ProcessGroup, torch.distributed.P2POp, torch.distributed.ReduceOp | 是 | | name | 指定的名称 | str | 是 | | save_backward | 是否保存反向数据 | boolean | 否 | diff --git a/debug/accuracy_tools/msprobe/docs/06.data_dump_MindSpore.md b/debug/accuracy_tools/msprobe/docs/06.data_dump_MindSpore.md index f8670c93c308b76bb2f177a3342d1a85f8e868fb..c443932d431d4969ee84e2864895b1636aa3f722 100644 --- a/debug/accuracy_tools/msprobe/docs/06.data_dump_MindSpore.md +++ b/debug/accuracy_tools/msprobe/docs/06.data_dump_MindSpore.md @@ -144,7 +144,7 @@ save(variable, name, save_backward=True) **参数说明**: | 参数名称 | 参数含义 | 支持数据类型 | 是否必选| | ---------- | ------------------| ------------------- | ------------------- | -| variable | 需要保存的变量 |dict, list, tuple, torch.tensor, int, float, str | 是 | +| variable | 需要保存的变量 |dict, list, tuple, np.integer, np.floating, np.bool_, np.complexfloating, np.str_, np.byte, np.unicode_, np.ndarray, bool, int, float, str, slice, type(Ellipsis), mindspore.Tensor, mindspore._c_expression.typing.Number | 是 | | name | 指定的名称 | str | 是 | | save_backward | 是否保存反向数据 | boolean | 否 | diff --git a/debug/accuracy_tools/msprobe/docs/28.debugger_save_instruction.md b/debug/accuracy_tools/msprobe/docs/28.debugger_save_instruction.md index 6f4d519d5f61d5efaaffe54a1bde4f140b539f72..74b211649a391198ff3a1b7c9462829882dab924 100644 --- a/debug/accuracy_tools/msprobe/docs/28.debugger_save_instruction.md +++ b/debug/accuracy_tools/msprobe/docs/28.debugger_save_instruction.md @@ -52,8 +52,8 @@ L0, L1, mix dump存在盲区,网络中的非api/module的输入输出不会被 初始化 ```python -# 训练启动py脚本 -from mindspore.pytorch import PrecisionDebugger +# 训练启动py脚本 以pytorch场景为例 +from msprobe.pytorch import PrecisionDebugger debugger = PrecisionDebugger("./config.json") for data, label in data_loader: # 执行模型训练 @@ -63,8 +63,8 @@ for data, label in data_loader: 初始化(无配置文件) ```python -# 训练启动py脚本 -from mindspore.pytorch import PrecisionDebugger +# 训练启动py脚本 以pytorch场景为例 +from msprobe.pytorch import PrecisionDebugger debugger = PrecisionDebugger(dump_path="dump_path", level="debug") for data, label in data_loader: # 执行模型训练 @@ -74,8 +74,8 @@ for data, label in data_loader: 调用保存接口 ```python -# 训练过程中被调用py文件 -from mindspore.pytorch import PrecisionDebugger +# 训练过程中被调用py文件 以pytorch场景为例 +from msprobe.pytorch import PrecisionDebugger dict_variable = {"key1": "value1", "key2": [1, 2]} PrecisionDebugger.save(dict_variable, "dict_variable", save_backward=False) diff --git a/debug/accuracy_tools/msprobe/mindspore/common/utils.py b/debug/accuracy_tools/msprobe/mindspore/common/utils.py index 625842da589a3090cddc75c50175ac577f1777b6..7f15cdafea490e48ad4e576d05c79e8ea47c22a3 100644 --- a/debug/accuracy_tools/msprobe/mindspore/common/utils.py +++ b/debug/accuracy_tools/msprobe/mindspore/common/utils.py @@ -25,7 +25,7 @@ from msprobe.core.common.exceptions import DistributedNotInitializedError from msprobe.core.common.file_utils import path_len_exceeds_limit, check_path_exists, save_npy from msprobe.core.common.log import logger from msprobe.core.common.const import Const -from msprobe.core.common.utils import CompareException, check_seed_all +from msprobe.core.common.utils import CompareException, check_seed_all, is_save_variable_valid class MsprobeStep(ms.train.Callback): @@ -192,9 +192,12 @@ def set_register_backward_hook_functions(): def check_save_param(variable, name, save_backward): # try catch this api to skip invalid call - if not isinstance(variable, (list, dict, tuple, ms.Tensor, int, float, str)): + from msprobe.core.data_dump.data_processor.mindspore_processor import MindsporeDataProcessor + valid_data_types = MindsporeDataProcessor.get_special_types() + if not is_save_variable_valid(variable, valid_data_types): + valid_data_types_with_nested_types = valid_data_types + ("dataclass", dict, tuple, list) logger.warning("PrecisionDebugger.save variable type not valid, " - "should be one of list, dict, tuple, ms.Tensor, int, float or string. " + f"should be one of {valid_data_types_with_nested_types}" "Skip current save process.") raise ValueError if not isinstance(name, str): diff --git a/debug/accuracy_tools/msprobe/pytorch/common/utils.py b/debug/accuracy_tools/msprobe/pytorch/common/utils.py index a7c2e0569cd745a4b7d28d7ead189394fcbd887d..8b0b18f679b8f63d1bfbd5ff8e3b980e207a340f 100644 --- a/debug/accuracy_tools/msprobe/pytorch/common/utils.py +++ b/debug/accuracy_tools/msprobe/pytorch/common/utils.py @@ -28,7 +28,7 @@ from msprobe.core.common.exceptions import DistributedNotInitializedError from msprobe.core.common.file_utils import (FileCheckConst, change_mode, check_file_or_directory_path, check_path_before_create, FileOpen) from msprobe.core.common.log import logger -from msprobe.core.common.utils import check_seed_all +from msprobe.core.common.utils import check_seed_all, is_save_variable_valid from packaging import version try: @@ -457,9 +457,12 @@ def is_recomputation(): def check_save_param(variable, name, save_backward): # try catch this api to skip invalid call - if not isinstance(variable, (list, dict, tuple, torch.Tensor, int, float, str)): + from msprobe.core.data_dump.data_processor.pytorch_processor import PytorchDataProcessor + valid_data_types = PytorchDataProcessor.get_special_types() + if not is_save_variable_valid(variable, valid_data_types): + valid_data_types_with_nested_types = valid_data_types + ("dataclass", dict, tuple, list) logger.warning("PrecisionDebugger.save variable type not valid, " - "should be one of list, dict, tuple, torch.Tensor, int, float or string. " + f"should be one of {valid_data_types_with_nested_types}" "Skip current save process.") raise ValueError if not isinstance(name, str): diff --git a/debug/accuracy_tools/msprobe/test/core_ut/common/test_utils.py b/debug/accuracy_tools/msprobe/test/core_ut/common/test_utils.py index df5d6c7857ec69d0ce71e314e8988cc2ac68eb8b..0efd470d6dbbcd9e9b00402e4542ef888a77d831 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/common/test_utils.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/common/test_utils.py @@ -25,6 +25,7 @@ from unittest.mock import MagicMock, mock_open, patch import OpenSSL import numpy as np from pathlib import Path +from dataclasses import dataclass from msprobe.core.common.const import Const from msprobe.core.common.file_utils import ( @@ -47,15 +48,16 @@ from msprobe.core.common.utils import (CompareException, check_regex_prefix_format_valid, set_dump_path, get_dump_mode, - get_real_step_or_rank, - get_step_or_rank_from_string, + get_real_step_or_rank, + get_step_or_rank_from_string, get_stack_construct_by_dump_json_path, check_seed_all, safe_get_value, MsprobeBaseException, check_str_param, is_json_file, - detect_framework_by_dump_json) + detect_framework_by_dump_json, + is_save_variable_valid) from msprobe.core.common.decorator import recursion_depth_decorator @@ -337,7 +339,7 @@ class TestUtils(TestCase): def test_recursion_depth_decorator(self, mock_error): # 测试递归深度限制函数 recursion_list = [[]] - temp_list = recursion_list[0] + temp_list = recursion_list[0] for _ in range(Const.MAX_DEPTH): temp_list.append([]) temp_list = temp_list[0] @@ -530,3 +532,51 @@ class TestDetectFrameworkByDumpJson(unittest.TestCase): result = detect_framework_by_dump_json(file_path) self.assertEqual(context.exception.code, CompareException.INVALID_PARAM_ERROR) mock_logger.error.assert_called_once_with(f"{file_path} must be based on the MindSpore or PyTorch framework.") + + +@dataclass +class DumyDataClass: + a: int + b: str + +class TestIsSaveVariableValid(unittest.TestCase): + def setUp(self): + self.valid_special_types = (int, float, str, bool) + + def test_is_save_variable_valid_DepthExceeded_ReturnsFalse(self): + # 创建一个深度超过 Const.DUMP_MAX_DEPTH 的嵌套结构 + nested_structure = [0] * Const.DUMP_MAX_DEPTH + for _ in range(Const.DUMP_MAX_DEPTH): + nested_structure = [nested_structure] + self.assertFalse(is_save_variable_valid(nested_structure, self.valid_special_types)) + + def test_is_save_variable_valid_ValidSpecialTypes_ReturnsTrue(self): + for valid_type in self.valid_special_types: + self.assertTrue(is_save_variable_valid(valid_type(0), self.valid_special_types)) + + def test_is_save_variable_valid_DataClass_ReturnsTrue(self): + self.assertTrue(is_save_variable_valid(DumyDataClass(1, "test"), self.valid_special_types)) + + def test_is_save_variable_valid_ListWithValidElements_ReturnsTrue(self): + self.assertTrue(is_save_variable_valid([1, 2, 3], self.valid_special_types)) + + def test_is_save_variable_valid_ListWithInvalidElement_ReturnsFalse(self): + self.assertFalse(is_save_variable_valid([1, "test", [1, slice(1)]], self.valid_special_types)) + + def test_is_save_variable_valid_TupleWithValidElements_ReturnsTrue(self): + self.assertTrue(is_save_variable_valid((1, 2, 3), self.valid_special_types)) + + def test_is_save_variable_valid_TupleWithInvalidElement_ReturnsFalse(self): + self.assertFalse(is_save_variable_valid((1, "test", [1, slice(1)]), self.valid_special_types)) + + def test_is_save_variable_valid_DictWithValidElements_ReturnsTrue(self): + self.assertTrue(is_save_variable_valid({"a": 1, "b": "test"}, self.valid_special_types)) + + def test_is_save_variable_valid_DictWithInvalidKey_ReturnsFalse(self): + self.assertFalse(is_save_variable_valid({1: "test"}, self.valid_special_types)) + + def test_is_save_variable_valid_DictWithInvalidValue_ReturnsFalse(self): + self.assertFalse(is_save_variable_valid({"a": [1, slice(1)]}, self.valid_special_types)) + + def test_is_save_variable_valid_None_ReturnsTrue(self): + self.assertTrue(is_save_variable_valid(None, self.valid_special_types)) \ No newline at end of file