From 49b9f93451f146150b1046fd548939f54ba36665 Mon Sep 17 00:00:00 2001 From: sunyiming Date: Tue, 8 Aug 2023 09:33:00 +0000 Subject: [PATCH 1/2] update debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py. Signed-off-by: sunyiming --- .../api_accuracy_checker/dump/dump_scope.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py b/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py index 51dbd75d9c8..df42463e871 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py @@ -1 +1,17 @@ -# dump范围控制 ———— 李天 \ No newline at end of file +import torch +from torch.utils.data import Dataset, DataLoader +from torch.utils.data.dataloader import _BaseDataLoaderIter + + +from api_accuracy_checker.dump.dump import DumpUtil + +_BaseDataLoaderIter.__next__ = iter_tracer(torch.utils.data.dataloader._BaseDataLoaderIter.__next__) + + +def iter_tracer(func): + def func_wrapper(*args, **kwargs): + DumpUtil.dump_switch = "OFF " + result = func(*args, **kwargs) + DumpUtil.incr_iter_num_maybe_exit() + return result + return func_wrapper \ No newline at end of file -- Gitee From 442272334d32e9724a9636dd840b1b43acf91def Mon Sep 17 00:00:00 2001 From: sunyiming Date: Tue, 8 Aug 2023 09:35:06 +0000 Subject: [PATCH 2/2] update debug/accuracy_tools/api_accuracy_checker/dump/dump.py. Signed-off-by: sunyiming --- .../accuracy_tools/api_accuracy_checker/dump/dump.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/dump.py b/debug/accuracy_tools/api_accuracy_checker/dump/dump.py index 8098f25db0b..61fd5f93a59 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/dump.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/dump.py @@ -38,6 +38,8 @@ def set_dump_switch(switch): class DumpUtil(object): dump_switch = None + target_iter_range = 1 + call_num = 0 @staticmethod def set_dump_switch(switch): @@ -47,6 +49,16 @@ class DumpUtil(object): def get_dump_switch(): return DumpUtil.dump_switch == "ON" + @staticmethod + def incr_iter_num_maybe_exit(): + if DumpUtil.call_num == DumpUtil.target_iter_range : + DumpUtil.dump_switch = "ON" + elif DumpUtil.call_num > DumpUtil.target_iter_range: + raise Exception("Model pretest: exit after iteration {}".format(DumpUtil.target_iter_range)) + else: + DumpUtil.dump_switch = "OFF" + DumpUtil.call_num += 1 + class DumpConst: delimiter = '*' -- Gitee