diff --git "a/debug/accuracy_tools/api_accuracy_checker/Ascend\346\250\241\345\236\213\347\262\276\345\272\246\351\242\204\346\243\200\345\267\245\345\205\267\344\275\277\347\224\250\346\226\271\346\263\225.md" "b/debug/accuracy_tools/api_accuracy_checker/Ascend\346\250\241\345\236\213\347\262\276\345\272\246\351\242\204\346\243\200\345\267\245\345\205\267\344\275\277\347\224\250\346\226\271\346\263\225.md" index 74e6ff59ac19bed0746877c4693184c376525a2c..a72e23484c1345084c98b8f3c2989bcf50fe86b8 100644 --- "a/debug/accuracy_tools/api_accuracy_checker/Ascend\346\250\241\345\236\213\347\262\276\345\272\246\351\242\204\346\243\200\345\267\245\345\205\267\344\275\277\347\224\250\346\226\271\346\263\225.md" +++ "b/debug/accuracy_tools/api_accuracy_checker/Ascend\346\250\241\345\236\213\347\262\276\345\272\246\351\242\204\346\243\200\345\267\245\345\205\267\344\275\277\347\224\250\346\226\271\346\263\225.md" @@ -17,10 +17,15 @@ export PYTHONPATH=$PYTHONPATH:{att_root}/debug/accuracy_tools/ ``` -2. 在工具中加入以下代码使用工具dump模块,启动训练抓取网络所有API信息,目前工具仅支持抓取训练的第一个迭代并且在第一个迭代后会退出训练进程。 +2. 在训练代码中加入以下代码以使用工具dump模块,启动训练抓取网络所有API信息: ``` - from api_accuracy_checker.dump import set_dump_switch + import api_accuracy_checker.dump + ``` + + 目前工具仅支持抓取训练的**第二个迭代**并且在第二个迭代后会报错退出训练进程。报错信息如下,这个报错仅用于停止训练,属于正常现象: + ``` + Exception: Model pretest: exit after iteration 1
``` ​ dump信息默认会存盘到./路径下,包括前向API信息forward_info_{pid}.json, 反向API信息backward_info_{pid}.json, 调用栈信息stack_info_{pid}.json。真实数据模式下还有forward_real_data和backward_real_data文件夹,里面有每个api输入的具体数值。forward_info与stack_info中的key值一一对应,用户可根据forward_info中API的key在stack_info中查询到其调用栈及代码行位置。 @@ -46,7 +51,7 @@ - + diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/__init__.py b/debug/accuracy_tools/api_accuracy_checker/dump/__init__.py index d5b19ad684780df6a3a97da7bf8ee7e1fcdd336f..1b19415c37261c0f6ec284c690a9af22750637c3 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/__init__.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/__init__.py @@ -1,4 +1,4 @@ from api_accuracy_checker.dump.dump import set_dump_switch - +import api_accuracy_checker.dump.dump_scope __all__ = ['set_dump_switch'] diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/dump.py b/debug/accuracy_tools/api_accuracy_checker/dump/dump.py index 8098f25db0b7f60f4a75d3514bd6b16b8ed0bc18..958c7ce51fb4c2e8ea02738672a8c02b2b91e913 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/dump.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/dump.py @@ -38,6 +38,8 @@ def set_dump_switch(switch): class DumpUtil(object): dump_switch = None + target_iter = 1 + call_num = 0 @staticmethod def set_dump_switch(switch): @@ -46,6 +48,16 @@ class DumpUtil(object): @staticmethod def get_dump_switch(): return DumpUtil.dump_switch == "ON" + + @staticmethod + def incr_iter_num_maybe_exit(): + if DumpUtil.call_num == DumpUtil.target_iter: + set_dump_switch("ON") + elif DumpUtil.call_num > DumpUtil.target_iter: + raise Exception("Model pretest: exit after iteration {}".format(DumpUtil.target_iter)) + else: + set_dump_switch("OFF") + DumpUtil.call_num += 1 class DumpConst: diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py b/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py index 51dbd75d9c87ff42d6730c51af017cf6b6e03fe8..16078173eaa7589880b2a75135aa7523bcdf59eb 100644 
--- a/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py
+++ b/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py
@@ -1 +1,36 @@
-# dump范围控制 ———— 李天
\ No newline at end of file
+# Dump scope control
+import functools
+
+import torch
+from torch.utils.data import Dataset, DataLoader
+from torch.utils.data.dataloader import _BaseDataLoaderIter
+from api_accuracy_checker.dump.dump import DumpUtil
+
+
+def iter_tracer(func):
+    """Wrap a data-loader ``__next__`` so each fetch advances the dump state.
+
+    Args:
+        func: The original ``__next__`` callable to wrap.
+
+    Returns:
+        A wrapper that returns whatever ``func`` returns.
+
+    Raises:
+        Exception: Propagated from ``DumpUtil.incr_iter_num_maybe_exit`` once
+            the call count exceeds ``DumpUtil.target_iter``.
+    """
+    @functools.wraps(func)  # preserve the wrapped function's metadata
+    def func_wrapper(*args, **kwargs):
+        # Dumping is switched off while the batch is fetched; presumably so
+        # data-pipeline API calls stay out of the dump -- TODO confirm.
+        DumpUtil.dump_switch = "OFF"
+        result = func(*args, **kwargs)
+        # Sets the switch for the upcoming step, or raises to stop training.
+        DumpUtil.incr_iter_num_maybe_exit()
+        return result
+    return func_wrapper
+
+
+# Import-time side effect: patches _BaseDataLoaderIter.__next__ in place.
+_BaseDataLoaderIter.__next__ = iter_tracer(torch.utils.data.dataloader._BaseDataLoaderIter.__next__)
\ No newline at end of file