diff --git a/debug/accuracy_tools/msprobe/README.md b/debug/accuracy_tools/msprobe/README.md index 7a20b6945216b84ffa6baf1014cb34a82b023e4b..b41b088fe58519a3568eceee7f01efa29ef0ebbd 100644 --- a/debug/accuracy_tools/msprobe/README.md +++ b/debug/accuracy_tools/msprobe/README.md @@ -9,7 +9,8 @@ 为方便使用,本工具提供了一个统一、简易的程序接口,**PrecisionDebugger**,以 PyTorch 框架为例,通过以下示例模板和 **config.json** 可轻松使用各种功能。 ```python -from msprobe.pytorch import PrecisionDebugger +from msprobe.pytorch import PrecisionDebugger, seed_all +seed_all(mode=False) debugger = PrecisionDebugger(config_path='./config.json') ... diff --git a/debug/accuracy_tools/msprobe/docs/05.data_dump_PyTorch.md b/debug/accuracy_tools/msprobe/docs/05.data_dump_PyTorch.md index 79be472aaff56ba129faaa83ab42c46249d880c4..0ef7666b3d113e036f9d718f3a5a3c11e7a9b82d 100644 --- a/debug/accuracy_tools/msprobe/docs/05.data_dump_PyTorch.md +++ b/debug/accuracy_tools/msprobe/docs/05.data_dump_PyTorch.md @@ -26,7 +26,20 @@ PrecisionDebugger(config_path=None, task=None, dump_path=None, level=None, model 1. config_path:指定 dump 配置文件路径;model:指定具体的 torch.nn.Module,默认未配置,level 配置为"L0"或"mix"时必须配置该参数。其他参数均在 [config.json](../config.json) 文件中可配,详细配置可见 [config.json 介绍](./02.config_introduction.md)。 2. 此接口的参数均不是必要,且优先级高于 [config.json](../config.json) 文件中的配置,但可配置的参数相比 config.json 较少。 -### 1.2 start +### 1.2 seed_all + +**功能说明**:固定随机种子和确定性。在工具导入之后的位置添加。 + +**原型**: + +```Python +seed_all(mode, seed=1234) +``` + +1. mode:是否开启确定性计算,bool类型,True表示开启确定性计算,False表示不开启确定性计算 +2. seed: 随机种子,int类型,默认为1234 + +### 1.3 start **功能说明**:启动精度数据采集。在模型初始化之后的位置添加。需要与 stop 函数一起添加在 for 循环内。 @@ -36,7 +49,7 @@ PrecisionDebugger(config_path=None, task=None, dump_path=None, level=None, model debugger.start() ``` -### 1.3 stop +### 1.4 stop **功能说明**:停止精度数据采集。在 **start** 函数之后的任意位置添加。若需要 dump 反向数据,则需要添加在反向计算代码(如,loss.backward)之后。使用示例可参见 [2.1 快速上手](#21-快速上手)和 [2.2 采集完整的前反向数据](#22-采集完整的前反向数据)。 @@ -46,7 +59,7 @@ debugger.start() debugger.stop() ``` -### 1.4 forward_backward_dump_end +### 1.5 forward_backward_dump_end **功能说明**:停止精度数据采集。用于 dump 指定代码的前反向数据。在 **start** 函数之后,反向计算代码(如,loss.backward)之前的任意位置添加,可以采集 **start** 函数和该函数之间的前反向数据,可以通过调整 **start** 函数与该函数的位置,来指定需要 dump 的代码块。要求 **stop** 函数添加在反向计算代码(如,loss.backward)之后,此时该函数与 **stop** 函数之间的代码不会被 dump。使用示例可参见 [2.3 采集指定代码块的前反向数据](#23-采集指定代码块的前反向数据) @@ -56,7 +69,7 @@ debugger.stop() forward_backward_dump_end() ``` -### 1.5 step +### 1.6 step **功能说明**:更新 dump 参数。在最后一个 **stop** 函数后或一个 step 结束的位置添加。需要与 **start** 函数一起添加在 for 循环内。 @@ -78,7 +91,8 @@ import torch import torch.nn as nn import torch_npu # 需安装 torch_npu import torch.nn.functional as F -from msprobe.pytorch import PrecisionDebugger +from msprobe.pytorch import PrecisionDebugger, seed_all +seed_all(mode=False) torch.npu.set_device("npu:0") @@ -110,7 +124,8 @@ if __name__ == "__main__": ### 2.2 采集完整的前反向数据 ```Python -from msprobe.pytorch import PrecisionDebugger +from msprobe.pytorch import PrecisionDebugger, seed_all +seed_all(mode=False) # 请勿将PrecisionDebugger的初始化流程插入到循环代码中 debugger = PrecisionDebugger(config_path="./config.json", dump_path="./dump_path") @@ -130,7 +145,8 @@ for data, label in data_loader: ### 2.3 采集指定代码块的前反向数据 ```Python -from msprobe.pytorch import PrecisionDebugger +from msprobe.pytorch import PrecisionDebugger, seed_all +seed_all(mode=False) # 请勿将PrecisionDebugger的初始化流程插入到循环代码中 debugger = PrecisionDebugger(config_path="./config.json", dump_path="./dump_path") diff --git a/debug/accuracy_tools/msprobe/pytorch/__init__.py b/debug/accuracy_tools/msprobe/pytorch/__init__.py index c4e426772670212382addb9b855b4bdf69810d3d..465eabb17f0cee2835b7afeb60b204942e49a42d 100644 --- a/debug/accuracy_tools/msprobe/pytorch/__init__.py +++ b/debug/accuracy_tools/msprobe/pytorch/__init__.py @@ -1,4 +1,4 @@ from .debugger.precision_debugger import PrecisionDebugger from .common.utils import seed_all from .compare.distributed_compare import compare_distributed -from .compare.pt_compare import compare \ No newline at end of file +from .compare.pt_compare import compare diff --git a/debug/accuracy_tools/msprobe/pytorch/common/utils.py b/debug/accuracy_tools/msprobe/pytorch/common/utils.py index c3bc033d1e107f54eecb1fa1adafde2f40c8447c..7241445ca81dda9248d79e98f06ebbfd35ac7f27 100644 --- a/debug/accuracy_tools/msprobe/pytorch/common/utils.py +++ b/debug/accuracy_tools/msprobe/pytorch/common/utils.py @@ -104,7 +104,11 @@ def get_rank_if_initialized(): raise DistributedNotInitializedError("torch distributed environment is not initialized") -def seed_all(seed=1234, mode=False): +def seed_all(mode, seed=1234): + if not isinstance(mode, bool): + raise ValueError(f"Invalid input parameter 'mode', the expected type bool but got {type(mode)}.") + if not isinstance(seed, int): + raise ValueError(f"Invalid input parameter 'seed', the expected type int but got {type(seed)}.") random.seed(seed) os.environ['PYTHONHASHSEED'] = str(seed) np.random.seed(seed) diff --git a/debug/accuracy_tools/msprobe/pytorch/debugger/debugger_config.py b/debug/accuracy_tools/msprobe/pytorch/debugger/debugger_config.py index 979bc90802858c7c05948756661420cb99294ff7..0e82beef93dd376c07d4efb8e09227df927399c9 100644 --- a/debug/accuracy_tools/msprobe/pytorch/debugger/debugger_config.py +++ b/debug/accuracy_tools/msprobe/pytorch/debugger/debugger_config.py @@ -56,7 +56,6 @@ class DebuggerConfig: for index, scope_spec in enumerate(self.scope): self.scope[index] = scope_spec.replace(Const.BACKWARD, Const.FORWARD) self.backward_input[self.scope[index]] = self.backward_input_list[index] - seed_all(self.seed, self.is_deterministic) def check_kwargs(self): if self.task and self.task not in Const.TASK_LIST: