diff --git a/debug/accuracy_tools/api_accuracy_checker/common/base_api.py b/debug/accuracy_tools/api_accuracy_checker/common/base_api.py index 4d1ebff9198f1707bf76555801fd4df154ca8d98..e76250d1dd3af806a3dc87d6fe531da25f216329 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/base_api.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/base_api.py @@ -3,6 +3,7 @@ import torch from api_accuracy_checker.common.utils import print_error_log, write_pt, create_directory from ptdbg_ascend.src.python.ptdbg_ascend.common.utils import check_path_before_create from api_accuracy_checker.common.config import msCheckerConfig +from api_accuracy_checker.dump.dump_utils import DumpUtil class BaseAPIInfo: @@ -55,7 +56,6 @@ class BaseAPIInfo: single_arg.update({'requires_grad': arg.requires_grad}) else: api_args = self.api_name + '.' + str(self.args_num) - from api_accuracy_checker.dump.dump import DumpUtil step_dir = "step" + str(DumpUtil.call_num - 1 if msCheckerConfig.enable_dataloader else DumpUtil.call_num) rank_dir = f"rank{self.rank}" if self.is_forward: diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/__init__.py b/debug/accuracy_tools/api_accuracy_checker/dump/__init__.py index f3e3fe66364169f8d1617acfd378905e225a52d2..8a25a2463dc7ea1b3b9d913abfec8f233b1dbd5c 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/__init__.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/__init__.py @@ -1,5 +1,4 @@ -from api_accuracy_checker.dump.dump import set_dump_switch import api_accuracy_checker.dump.dump_scope from api_accuracy_checker.common.config import msCheckerConfig -__all__ = ['set_dump_switch'] +__all__ = [] diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/dump.py b/debug/accuracy_tools/api_accuracy_checker/dump/dump.py index a3a8fee7472c48e3df19f4dabd326071e5465dd5..4b3516fd574ab9330814472295148b0dcaafba37 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/dump.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/dump.py @@ -20,6 +20,7 @@ from api_accuracy_checker.dump.info_dump import write_api_info_json, initialize_ from api_accuracy_checker.common.utils import print_error_log, CompareException from api_accuracy_checker.hook_module.register_hook import initialize_hook from api_accuracy_checker.common.config import msCheckerConfig +from api_accuracy_checker.dump.dump_utils import DumpUtil, DumpConst def set_dump_switch(switch): @@ -32,6 +33,15 @@ def set_dump_switch(switch): DumpUtil.set_dump_switch(switch) +def incr_iter_num_maybe_exit(): + if DumpUtil.call_num in msCheckerConfig.target_iter: + set_dump_switch("ON") + elif DumpUtil.call_num > max(msCheckerConfig.target_iter): + raise Exception("Model pretest: exit after iteration {}".format(DumpUtil.call_num - 1)) + else: + set_dump_switch("OFF") + + def check_dataloader_status(): if msCheckerConfig.enable_dataloader: error_info = ("If you want to use this function, set enable_dataloader " @@ -56,34 +66,6 @@ def step(): DumpUtil.call_num += 1 -class DumpUtil(object): - dump_switch = None - call_num = 0 - - @staticmethod - def set_dump_switch(switch): - DumpUtil.dump_switch = switch - - @staticmethod - def get_dump_switch(): - return DumpUtil.dump_switch == "ON" - - @staticmethod - def incr_iter_num_maybe_exit(): - if DumpUtil.call_num in msCheckerConfig.target_iter: - set_dump_switch("ON") - elif DumpUtil.call_num > max(msCheckerConfig.target_iter): - raise Exception("Model pretest: exit after iteration {}".format(DumpUtil.call_num - 1)) - else: - set_dump_switch("OFF") - - -class DumpConst: - delimiter = '*' - forward = 'forward' - backward = 'backward' - - def pretest_info_dump(name, out_feat, module, phase): if not DumpUtil.get_dump_switch(): return diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py b/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py index a297e7235f7b7a196112d8fa857513c4d5027f03..e2edc15bc6da4712db0aafffea75adf1e5b458fb 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py @@ -1,7 +1,8 @@ # dump范围控制 import torch from torch.utils.data.dataloader import _BaseDataLoaderIter -from api_accuracy_checker.dump.dump import DumpUtil +from api_accuracy_checker.dump.dump_utils import DumpUtil +from api_accuracy_checker.dump.dump import incr_iter_num_maybe_exit from api_accuracy_checker.common.config import msCheckerConfig @@ -9,7 +10,7 @@ def iter_tracer(func): def func_wrapper(*args, **kwargs): DumpUtil.dump_switch = "OFF" result = func(*args, **kwargs) - DumpUtil.incr_iter_num_maybe_exit() + incr_iter_num_maybe_exit() DumpUtil.call_num += 1 return result return func_wrapper diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/dump_utils.py b/debug/accuracy_tools/api_accuracy_checker/dump/dump_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1ab8bdeace4fd797b56fd46f2b3fe5f4092cced7 --- /dev/null +++ b/debug/accuracy_tools/api_accuracy_checker/dump/dump_utils.py @@ -0,0 +1,17 @@ +class DumpUtil(object): + dump_switch = None + call_num = 0 + + @staticmethod + def set_dump_switch(switch): + DumpUtil.dump_switch = switch + + @staticmethod + def get_dump_switch(): + return DumpUtil.dump_switch == "ON" + + +class DumpConst: + delimiter = '*' + forward = 'forward' + backward = 'backward' \ No newline at end of file diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py b/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py index 70d49251f060f88477695d54f2efef776d776334..3177497cde312727a3177837c0b3a5b5ec5572e4 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py @@ -7,6 +7,7 @@ from api_accuracy_checker.dump.api_info import ForwardAPIInfo, BackwardAPIInfo from api_accuracy_checker.common.utils import check_file_or_directory_path, initialize_save_path, create_directory from ptdbg_ascend.src.python.ptdbg_ascend.common.utils import check_path_before_create from api_accuracy_checker.common.config import msCheckerConfig +from api_accuracy_checker.dump.dump_utils import DumpUtil from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import FileOpen, FileCheckConst, FileChecker @@ -15,7 +16,6 @@ lock = threading.Lock() def write_api_info_json(api_info): - from api_accuracy_checker.dump.dump import DumpUtil dump_path = msCheckerConfig.dump_path dump_path = os.path.join(msCheckerConfig.dump_path, "step" + str((DumpUtil.call_num - 1) if msCheckerConfig.enable_dataloader else DumpUtil.call_num)) check_path_before_create(dump_path)