diff --git a/debug/accuracy_tools/msprobe/core/common/const.py b/debug/accuracy_tools/msprobe/core/common/const.py
index 591279429e260e4bae75f4db6a29926b0a8f5f88..1fd832fc7307209686dfba261bda7463c631dd5c 100644
--- a/debug/accuracy_tools/msprobe/core/common/const.py
+++ b/debug/accuracy_tools/msprobe/core/common/const.py
@@ -113,9 +113,12 @@ class Const:
RUN_UT = "run_ut"
GRAD_PROBE = "grad_probe"
STRUCTURE = "structure"
+ DUMP_PRECISION_HIGH = "high"
+ DUMP_PRECISION_LOW = "low"
TASK_LIST = [TENSOR, STATISTICS, OVERFLOW_CHECK, FREE_BENCHMARK, RUN_UT, GRAD_PROBE, STRUCTURE]
DUMP_DATA_COLLECTION_LIST = [STATISTICS, TENSOR, STRUCTURE]
DUMP_DATA_MODE_LIST = [ALL, INPUT, OUTPUT, FORWARD, BACKWARD]
+ DUMP_PRECISION_LIST = [DUMP_PRECISION_LOW, DUMP_PRECISION_HIGH]
LEVEL_L0 = "L0"
LEVEL_L1 = "L1"
LEVEL_L2 = "L2"
diff --git a/debug/accuracy_tools/msprobe/core/common_config.py b/debug/accuracy_tools/msprobe/core/common_config.py
index 34c3469cf3b6972174f1634ef7f41ecb9b0d01dd..104ed57deb346077e9acf0b025a108df5508175f 100644
--- a/debug/accuracy_tools/msprobe/core/common_config.py
+++ b/debug/accuracy_tools/msprobe/core/common_config.py
@@ -30,6 +30,7 @@ class CommonConfig:
self.level = json_config.get('level')
self.enable_dataloader = json_config.get('enable_dataloader', False)
self.async_dump = json_config.get("async_dump", False)
+ self.precision = json_config.get("precision", Const.DUMP_PRECISION_HIGH)
self._check_config()
def _check_config(self):
@@ -51,6 +52,10 @@ class CommonConfig:
elif self.async_dump:
logger.warning("async_dump is True, it may cause OOM when dumping large tensor.")
+ if self.precision not in Const.DUMP_PRECISION_LIST:
+ logger.error_log_with_exp("precision is invalid, it should be one of {}".format(Const.DUMP_PRECISION_LIST),
+ MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
+
class BaseConfig:
def __init__(self, json_config):
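For context, the new field comes from config.json next to the existing switches. Below is a minimal sketch of how it is consumed, assuming the remaining fields satisfy the existing checks; the example values are illustrative only:

    from msprobe.core.common_config import CommonConfig

    # Illustrative config.json content, expressed as the dict CommonConfig receives.
    json_config = {
        "task": "statistics",
        "dump_path": "./dump_path",
        "level": "L1",
        "async_dump": False,
        "precision": "low",  # new field; must be "high" (default) or "low"
    }
    config = CommonConfig(json_config)
    # An unsupported value such as "medium" would trip the check added above:
    # logger.error_log_with_exp reports the problem together with an MsprobeException.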
diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py
index f5c2b89a1b49e881e601e895bdfcdb9e59a51acf..1e8cb322f9f4dc518a10d690168e0b80b84fa18e 100644
--- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py
+++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py
@@ -65,52 +65,6 @@ class MindsporeDataProcessor(BaseDataProcessor):
def analyze_dtype_in_kwargs(element):
return {"type": "mindspore.dtype", "value": str(element)}
- @staticmethod
- def get_stat_info_sync(data):
- tensor_stat = TensorStatInfo()
- if data.dtype == ms.bool_:
- data_np = data.asnumpy()
- tensor_stat.max = np.max(data_np).item()
- tensor_stat.min = np.min(data_np).item()
- elif not data.shape:
- tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data.copy()
- elif data.dtype == ms.complex64 or data.dtype == ms.complex128:
- data_abs = np.abs(data.asnumpy())
- tensor_stat.max = np.max(data_abs).item()
- tensor_stat.min = np.min(data_abs).item()
- tensor_stat.mean = np.mean(data_abs).item()
- tensor_stat.norm = np.linalg.norm(data_abs).item()
- else:
- if not ops.is_floating_point(data) or data.dtype == ms.float64:
- data = data.to(ms.float32)
- get_norm_value = mint.norm if hasattr(mint, "norm") else ops.norm
- tensor_stat.max = mint.max(data)
- tensor_stat.min = mint.min(data)
- tensor_stat.mean = mint.mean(data)
- tensor_stat.norm = get_norm_value(data)
- return tensor_stat
-
- @staticmethod
- def get_stat_info_async(data):
- tensor_stat = TensorStatInfo()
- if data.dtype == ms.bool_:
- tensor_stat.max = mint.any(data)
- tensor_stat.min = mint.all(data)
- elif not data.shape:
- tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data.copy()
- elif data.dtype == ms.complex64 or data.dtype == ms.complex128:
- logger.warning("Async dump do not support complex data!")
- return tensor_stat
- else:
- if not ops.is_floating_point(data) or data.dtype == ms.float64:
- data = data.to(ms.float32)
- get_norm_value = mint.norm if hasattr(mint, "norm") else ops.norm
- tensor_stat.max = mint.max(data)
- tensor_stat.min = mint.min(data)
- tensor_stat.mean = mint.mean(data)
- tensor_stat.norm = get_norm_value(data)
- return tensor_stat
-
@staticmethod
def is_hookable_element(element):
return hasattr(element, "register_hook") and callable(element.register_hook)
@@ -147,14 +101,37 @@ class MindsporeDataProcessor(BaseDataProcessor):
self.api_register.restore_inner_used_api()
tensor_stat = TensorStatInfo()
if data.numel() == 0:
- stat_info = tensor_stat
- else:
+ pass
+ elif data.dtype == ms.bool_:
if self.config.async_dump:
- stat_info = MindsporeDataProcessor.get_stat_info_async(data)
+ tensor_stat.max = mint.any(data)
+ tensor_stat.min = mint.all(data)
else:
- stat_info = MindsporeDataProcessor.get_stat_info_sync(data)
+ data_np = data.asnumpy()
+ tensor_stat.max = np.max(data_np).item()
+ tensor_stat.min = np.min(data_np).item()
+ elif not data.shape:
+ tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data.copy()
+ elif data.dtype == ms.complex64 or data.dtype == ms.complex128:
+ if self.config.async_dump:
+ logger.warning("Async dump do not support complex data!")
+ else:
+ data_abs = np.abs(data.asnumpy())
+ tensor_stat.max = np.max(data_abs).item()
+ tensor_stat.min = np.min(data_abs).item()
+ tensor_stat.mean = np.mean(data_abs).item()
+ tensor_stat.norm = np.linalg.norm(data_abs).item()
+ else:
+ if (self.config.precision == Const.DUMP_PRECISION_HIGH
+         or not ops.is_floating_point(data) or data.dtype == ms.float64):
+ data = data.to(ms.float32)
+ get_norm_value = mint.norm if hasattr(mint, "norm") else ops.norm
+ tensor_stat.max = mint.max(data)
+ tensor_stat.min = mint.min(data)
+ tensor_stat.mean = mint.mean(data)
+ tensor_stat.norm = get_norm_value(data)
self.api_register.register_inner_used_api()
- return stat_info
+ return tensor_stat
def analyze_single_element(self, element, suffix_stack):
if suffix_stack and suffix_stack[-1] in self.mindspore_object_key:
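To summarize the branch above outside the diff: the precision switch only decides which dtype the statistics are computed in. A minimal sketch, assuming a MindSpore environment; the helper name stat_compute_dtype is illustrative and not part of msprobe:

    import mindspore as ms
    from mindspore import ops

    def stat_compute_dtype(data, precision="high"):
        # Mirrors the condition added above: cast to float32 when precision is
        # "high", when the input is not floating point, or when it is float64;
        # otherwise ("low") keep the tensor's own dtype to save device memory.
        if precision == "high" or not ops.is_floating_point(data) or data.dtype == ms.float64:
            return ms.float32
        return data.dtype

    x = ms.Tensor([1.0, 2.0], dtype=ms.float16)
    stat_compute_dtype(x, "high")  # ms.float32: more accurate, more device memory
    stat_compute_dtype(x, "low")   # ms.float16: original dtype, less device memory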
diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py
index a4141754981be151fca995bfc91955530bc8d933..398419fea058e81280d10e458cb67cb512f11051 100644
--- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py
+++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py
@@ -95,29 +95,17 @@ class PytorchDataProcessor(BaseDataProcessor):
return {"type": "torch.dtype", "value": str(element)}
@staticmethod
- def get_stat_info_async(data):
+ def get_stat_info(data, async_dump=False, precision=Const.DUMP_PRECISION_HIGH):
tensor_stat = TensorStatInfo()
- if torch.is_complex(data):
- logger.warning("Async dump do not support complex data!")
+ if data.is_meta:
+ return tensor_stat
+ data = data.detach()
+ if not data.numel() or not data.data_ptr():
return tensor_stat
- elif data.dtype == torch.bool:
- tensor_stat.max = torch.any(data)
- tensor_stat.min = torch.all(data)
- elif not data.shape:
- tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data.clone()
- else:
- if data.dtype == torch.float64 or not data.is_floating_point():
- data = data.float()
- tensor_stat.max = torch.max(data)
- tensor_stat.min = torch.min(data)
- tensor_stat.mean = torch.mean(data)
- tensor_stat.norm = torch.norm(data)
- return tensor_stat
-
- @staticmethod
- def get_stat_info_sync(data):
- tensor_stat = TensorStatInfo()
if torch.is_complex(data):
+ if async_dump:
+ logger.warning("Async dump do not support complex data!")
+ return tensor_stat
data_np = data.cpu().numpy()
data_abs = np.abs(data_np)
tensor_stat.max = np.max(data_abs).item()
@@ -129,7 +117,7 @@ class PytorchDataProcessor(BaseDataProcessor):
elif not data.shape:
tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data.clone()
else:
- if data.dtype == torch.float64 or not data.is_floating_point():
+ if precision == Const.DUMP_PRECISION_HIGH or data.dtype == torch.float64 or not data.is_floating_point():
data = data.float()
tensor_stat.max = torch.max(data)
tensor_stat.min = torch.min(data)
@@ -137,20 +125,6 @@ class PytorchDataProcessor(BaseDataProcessor):
tensor_stat.norm = torch.norm(data)
return tensor_stat
- @staticmethod
- def get_stat_info(data, async_dump=False):
- tensor_stat = TensorStatInfo()
- if data.is_meta:
- return tensor_stat
- data_clone = data.detach()
- if not data_clone.numel() or not data_clone.data_ptr():
- return tensor_stat
- else:
- if data_clone.device.type == Const.CPU_LOWERCASE or not async_dump:
- return PytorchDataProcessor.get_stat_info_sync(data_clone)
- else:
- return PytorchDataProcessor.get_stat_info_async(data_clone)
-
@staticmethod
def handle_tensor_extremum_nan_inf(tensor, operator):
data_clone = tensor.detach()
@@ -256,7 +230,7 @@ class PytorchDataProcessor(BaseDataProcessor):
return p2pop_info
def _analyze_tensor(self, tensor, suffix):
- tensor_stat = self.get_stat_info(tensor, self.config.async_dump)
+ tensor_stat = self.get_stat_info(tensor, self.config.async_dump, self.config.precision)
tensor_json = {}
tensor_json.update({'type': 'torch.Tensor'})
tensor_json.update({'dtype': str(tensor.dtype)})
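With get_stat_info_sync/get_stat_info_async folded into a single get_stat_info, callers choose the behavior through the two keyword arguments. An illustrative call follows (the tensor is made up; on CPU some half-precision reductions may be unavailable, the intended target is an accelerator tensor):

    import torch
    from msprobe.core.common.const import Const
    from msprobe.core.data_dump.data_processor.pytorch_processor import PytorchDataProcessor

    t = torch.randn(1024).to(torch.float16)
    # "high" (default): t is cast to float32 before max/min/mean/norm are computed.
    stats_high = PytorchDataProcessor.get_stat_info(t, precision=Const.DUMP_PRECISION_HIGH)
    # "low": the statistics stay in float16, trading some accuracy for device memory.
    stats_low = PytorchDataProcessor.get_stat_info(t, precision=Const.DUMP_PRECISION_LOW)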
diff --git a/debug/accuracy_tools/msprobe/docs/02.config_introduction.md b/debug/accuracy_tools/msprobe/docs/02.config_introduction.md
index d388d697fada86ae8378fb4b277e414b90c665bf..e00b9c86bd6fa1ba7d2b54292859fb6fb583172d 100644
--- a/debug/accuracy_tools/msprobe/docs/02.config_introduction.md
+++ b/debug/accuracy_tools/msprobe/docs/02.config_introduction.md
@@ -10,15 +10,16 @@
### 1.1 通用配置
-| Parameter | Description | Required |
-| ----------------- | ----------- | -------- |
+| Parameter         | Description | Required |
+|-------------------|-------------|----------|
| task | The dump task type, str. Options:<br>"statistics": collect statistics only (default);<br>"tensor": collect statistics and a full copy of the real data of the whole network;<br>"run_ut": accuracy pre-check, supported only in the PyTorch scenario; do not select it while collecting data;<br>"overflow_check": overflow detection;<br>"free_benchmark": benchmark-free comparison, not supported in the MSAdapter scenario;<br>"grad_probe": gradient monitoring, not supported in the MSAdapter scenario;<br>"structure": collect only the model structure and call-stack information, without the concrete data.<br>Different scenario parameters can be configured depending on the value of task, see:<br>[1.2 task set to statistics](#12-task-配置为-statistics),<br>[1.3 task set to tensor](#13-task-配置为-tensor),<br>[1.4 task set to run_ut](#14-task-配置为-run_ut),<br>[1.5 task set to overflow_check](#15-task-配置为-overflow_check),<br>[1.6 task set to free_benchmark](#16-task-配置为-free_benchmark),<br>[1.7 task set to grad_probe](#17-task-配置为-grad_probe),<br>[1.8 task set to structure](#18-task-配置为-structure).<br>**Configuration example**: "task": "tensor". | No |
| dump_path | Directory path for the dump data, str.<br>**Configuration example**: "dump_path": "./dump_path". | Yes |
| rank | Collect data only from the specified cards, list[Union[int, str]]. Unset by default (data from all cards is collected). Elements must be integers ≥ 0 or strings such as "4-6", and must be Rank IDs that are actually available.<br>PyTorch scenario: Rank IDs start from 0 and the maximum value is the total number of available cards across all nodes minus 1. If a configured value is larger than the Rank IDs actually used in training, the dump data is empty; for example, if the environment has Rank IDs 0 to 7 but training only runs on cards 0 to 3, configuring Rank ID 4, or a non-existent value such as 10, produces empty dump data.<br>MindSpore scenario: Rank IDs on every node start from 0 and the maximum value is the number of available cards per node minus 1; the rank parameter configured once in config.json takes effect on all nodes at the same time. Static-graph L0-level dump does not yet support specifying rank.<br>Note: for single-card training, rank must be [], i.e. an empty list; a rank cannot be specified.<br>**Configuration example**: "rank": [1, "4-6"]. | No |
| step | Collect data only for the specified steps, list[Union[int, str]]. Unset by default, which means data for all steps is collected. When collecting specific steps, they must be steps that actually occur in the training script; they can be listed one by one or given as ranges.<br>**Configuration example**: "step": [0, 1, 2, "4-6"]. | No |
| level | Dump level, str; different levels collect different data. Options:<br>"L0": dump module-level accuracy data; for background see [1.1.1 Module-level accuracy data dump](#111-模块级精度数据-dump-说明).<br>"L1": dump API-level accuracy data (default); supported only in the PyTorch, MSAdapter and MindSpore dynamic-graph scenarios.<br>"L2": dump kernel-level accuracy data; for the PyTorch scenario see [Kernel dump in the PyTorch scenario](./04.kernel_dump_PyTorch.md); for the MindSpore dynamic-graph scenario see [Kernel dump in the MindSpore dynamic-graph scenario](./28.kernel_dump_MindSpore.md); for the MindSpore static-graph scenario see section ["**8.1 Static-graph scenario**"](./06.data_dump_MindSpore.md#81-静态图场景) of "Data Collection in the MindSpore Scenario".<br>"mix": dump both module-level and API-level accuracy data, i.e. "L0"+"L1"; supported only in the PyTorch, MSAdapter and MindSpore dynamic-graph scenarios.<br>"debug": single-point save, see [Single-point save tool](./28.debugger_save_instruction.md).<br>**Configuration example**: "level": "L1". | No |
| enable_dataloader | Automatic control switch, bool; PyTorch only. Options: true (on) or false (off), default false. When set to true, the iterations specified by the step parameter are detected automatically and training exits after those iterations finish; the start, stop and step functions then do not need to be called. This switch requires the training script to load data through torch.utils.data.dataloader. Only single-card PyTorch training is supported; in distributed training the dumped data may be incomplete. **This feature will be deprecated in the next version.** | No |
-| async_dump | Asynchronous dump switch, bool. Supported when task is tensor or statistic and level is L0, L1, mix or debug. Options: true (on) or false (off), default false. When set to true, asynchronous dump is enabled: the collected accuracy data is written to disk together after the current training step finishes, and the tool triggers no synchronization during training. Because this mode carries a risk of **out-of-memory (OOM) on the device**, when task is tensor (asynchronous dump of real data) the [list](#13-task-配置为-tensor) parameter must be configured to specify the tensors to dump. In this mode, summary_mode does not support md5, and statistics are not computed for complex tensors. | No |
+| async_dump | Asynchronous dump switch, bool. Supported when task is tensor or statistics and level is L0, L1, mix or debug. Options: true (on) or false (off), default false. When set to true, asynchronous dump is enabled: the collected accuracy data is written to disk together after the current training step finishes, and the tool triggers no synchronization during training. Because this mode carries a risk of **out-of-memory (OOM) on the device**, when task is tensor (asynchronous dump of real data) the [list](#13-task-配置为-tensor) parameter must be configured to specify the tensors to dump. In this mode, summary_mode does not support md5, and statistics are not computed for complex tensors. | No |
+| precision | Controls the precision used to compute the statistics. Options: ["high", "low"], default "high". With "high", statistics are computed in float32, which uses more device memory and is more accurate, but may cause **out-of-memory (OOM) on the device** when handling large tensors; with "low", statistics are computed in the same dtype as the original data, using less device memory. Supported in the PyTorch, MindSpore dynamic-graph and MindSpore static-graph O0/O1 scenarios. Supported when task is set to statistics or tensor and level is set to L0, L1, mix or debug. A minimal example follows the table. | No |
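For illustration, a config.json that collects statistics with low-precision computation could look like the following dict (the values other than precision are examples only):

    # Dict form of an illustrative config.json enabling low-precision statistics.
    statistics_low_precision_config = {
        "task": "statistics",
        "dump_path": "./dump_path",
        "level": "mix",
        "async_dump": False,
        "precision": "low",  # compute statistics in the original dtype to save device memory
    }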
#### 1.1.1 模块级精度数据 dump 说明
@@ -46,7 +47,6 @@