diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py b/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py index a297e7235f7b7a196112d8fa857513c4d5027f03..1f65dbc9c8a7e482d8ac85e3d06cffc3b11b406a 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py @@ -5,14 +5,18 @@ from api_accuracy_checker.dump.dump import DumpUtil from api_accuracy_checker.common.config import msCheckerConfig -def iter_tracer(func): +def iter_tracer(original_next): def func_wrapper(*args, **kwargs): - DumpUtil.dump_switch = "OFF" - result = func(*args, **kwargs) - DumpUtil.incr_iter_num_maybe_exit() - DumpUtil.call_num += 1 - return result + if msCheckerConfig.enable_dataloader: + DumpUtil.dump_switch = "OFF" + result = original_next(*args, **kwargs) + DumpUtil.incr_iter_num_maybe_exit() + DumpUtil.call_num += 1 + return result + else: + return original_next(*args, **kwargs) return func_wrapper -if msCheckerConfig.enable_dataloader: - _BaseDataLoaderIter.__next__ = iter_tracer(torch.utils.data.dataloader._BaseDataLoaderIter.__next__) \ No newline at end of file +original_next_method = _BaseDataLoaderIter.__next__ + +_BaseDataLoaderIter.__next__ = iter_tracer(original_next_method) \ No newline at end of file