diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/__init__.py b/debug/accuracy_tools/api_accuracy_checker/dump/__init__.py index d2c12a0e52d3a93c9f0c1b9ff87a6d1feac5e033..73ec7988532c71009da691851e68d163dbe8dbfb 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/__init__.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/__init__.py @@ -1,10 +1,4 @@ -from api_accuracy_checker.hook_module.register_hook import initialize_hook -from api_accuracy_checker.dump.dump import pretest_hook -from api_accuracy_checker.dump.info_dump import initialize_output_json from api_accuracy_checker.dump.utils import set_dump_switch -initialize_hook(pretest_hook) -initialize_output_json() - -__all__ = ['set_dump_switch', 'msCheckerConfig'] \ No newline at end of file +__all__ = ['set_dump_switch'] diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/dump.py b/debug/accuracy_tools/api_accuracy_checker/dump/dump.py index e69c4b50ef548b51a816f5a5c2bfd9c50b4ae2e4..8098f25db0b7f60f4a75d3514bd6b16b8ed0bc18 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/dump.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/dump.py @@ -25,12 +25,37 @@ import threading from api_accuracy_checker.dump.api_info import ForwardAPIInfo, BackwardAPIInfo -from api_accuracy_checker.dump.info_dump import write_api_info_json -from api_accuracy_checker.dump.utils import DumpConst, DumpUtil +from api_accuracy_checker.dump.info_dump import write_api_info_json, initialize_output_json from api_accuracy_checker.common.utils import print_warn_log, print_info_log, print_error_log +from api_accuracy_checker.hook_module.register_hook import initialize_hook + + +def set_dump_switch(switch): + if switch == "ON": + initialize_hook(pretest_hook) + initialize_output_json() + DumpUtil.set_dump_switch(switch) + +class DumpUtil(object): + dump_switch = None + + @staticmethod + def set_dump_switch(switch): + DumpUtil.dump_switch = switch + + @staticmethod + def get_dump_switch(): + return DumpUtil.dump_switch == "ON" + + +class DumpConst: + delimiter = '*' + forward = 'forward' + backward = 'backward' + def pretest_info_dump(name, out_feat, module, phase): - if not DumpUtil.dump_switch: + if not DumpUtil.get_dump_switch(): return if phase == DumpConst.forward: api_info = ForwardAPIInfo(name, module.input_args, module.input_kwargs) diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/utils.py b/debug/accuracy_tools/api_accuracy_checker/dump/utils.py index 4a19785b61dd32a7e11f2ea313fab52f8f8cd7d0..93af6f0981aa0bd3d3c42606a50f04b79bc1c37b 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/utils.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/utils.py @@ -1,15 +1,6 @@ import os -import shutil -import sys -from pathlib import Path import numpy as np -from api_accuracy_checker.common.utils import print_error_log, CompareException, DumpException, Const, get_time, print_info_log, \ - check_mode_valid, get_api_name_from_matcher -class DumpConst: - delimiter = '*' - forward = 'forward' - backward = 'backward' def create_folder(path): if not os.path.exists(path): @@ -22,17 +13,3 @@ def write_npy(file_path, tensor): np.save(file_path, tensor) full_path = os.path.abspath(file_path) return full_path - -def set_dump_switch(switch): - DumpUtil.set_dump_switch(switch) - -class DumpUtil(object): - dump_switch = None - - @staticmethod - def set_dump_switch(switch): - DumpUtil.dump_switch = switch - - @staticmethod - def get_dump_switch(): - return DumpUtil.dump_switch == "ON" \ No newline at end of file