diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/dump.py b/debug/accuracy_tools/api_accuracy_checker/dump/dump.py index 8098f25db0b7f60f4a75d3514bd6b16b8ed0bc18..61fd5f93a59cbec538857b642639ef543af008cf 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/dump.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/dump.py @@ -38,6 +38,8 @@ def set_dump_switch(switch): class DumpUtil(object): dump_switch = None + target_iter_range = 1 + call_num = 0 @staticmethod def set_dump_switch(switch): @@ -47,6 +49,16 @@ class DumpUtil(object): def get_dump_switch(): return DumpUtil.dump_switch == "ON" + @staticmethod + def incr_iter_num_maybe_exit(): + if DumpUtil.call_num == DumpUtil.target_iter_range : + DumpUtil.dump_switch = "ON" + elif DumpUtil.call_num > DumpUtil.target_iter_range: + raise Exception("Model pretest: exit after iteration {}".format(DumpUtil.target_iter_range)) + else: + DumpUtil.dump_switch = "OFF" + DumpUtil.call_num += 1 + class DumpConst: delimiter = '*' diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py b/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py index 51dbd75d9c87ff42d6730c51af017cf6b6e03fe8..df42463e8715eadfe5aca2cb21258d6a2a7fcdd3 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py @@ -1 +1,17 @@ -# dump范围控制 ———— 李天 \ No newline at end of file +import torch +from torch.utils.data import Dataset, DataLoader +from torch.utils.data.dataloader import _BaseDataLoaderIter + + +from api_accuracy_checker.dump.dump import DumpUtil + +_BaseDataLoaderIter.__next__ = iter_tracer(torch.utils.data.dataloader._BaseDataLoaderIter.__next__) + + +def iter_tracer(func): + def func_wrapper(*args, **kwargs): + DumpUtil.dump_switch = "OFF " + result = func(*args, **kwargs) + DumpUtil.incr_iter_num_maybe_exit() + return result + return func_wrapper \ No newline at end of file