diff --git a/debug/accuracy_tools/api_accuracy_checker/common/utils.py b/debug/accuracy_tools/api_accuracy_checker/common/utils.py index e5a6b711004f4b2016cd30d28cdd3e4e15ac93ec..4e7cfbc4f66eed33993d11dc5e5635c77e231dda 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/utils.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/utils.py @@ -192,6 +192,25 @@ class CompareException(Exception): return self.error_info +class DumpUtil(object): + dump_switch = None + call_num = 0 + + @staticmethod + def set_dump_switch(switch): + DumpUtil.dump_switch = switch + + @staticmethod + def get_dump_switch(): + return DumpUtil.dump_switch == "ON" + + +class DumpConst: + delimiter = '*' + forward = 'forward' + backward = 'backward' + + class DumpException(CompareException): pass diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/__init__.py b/debug/accuracy_tools/api_accuracy_checker/dump/__init__.py index f3e3fe66364169f8d1617acfd378905e225a52d2..8a25a2463dc7ea1b3b9d913abfec8f233b1dbd5c 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/__init__.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/__init__.py @@ -1,5 +1,4 @@ -from api_accuracy_checker.dump.dump import set_dump_switch import api_accuracy_checker.dump.dump_scope from api_accuracy_checker.common.config import msCheckerConfig -__all__ = ['set_dump_switch'] +__all__ = [] diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py b/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py index adb0c4b0f34215b93c2bcbcce1d8bbd4bdac877b..135abeeb8037c500b059f9ec931a41b3ff52d5d4 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py @@ -3,7 +3,7 @@ import os import inspect import torch from api_accuracy_checker.common.config import msCheckerConfig -from api_accuracy_checker.common.utils import print_error_log, write_pt, create_directory, DumpException +from api_accuracy_checker.common.utils import print_error_log, write_pt, create_directory, DumpException, DumpUtil from ptdbg_ascend.src.python.ptdbg_ascend.common.utils import check_path_before_create @@ -69,7 +69,6 @@ class APIInfo: @staticmethod def get_full_save_path(save_path, dir_name, contain_step=False): if contain_step: - from api_accuracy_checker.dump.dump import DumpUtil step_dir = "step" + str(DumpUtil.call_num - 1 if msCheckerConfig.enable_dataloader else DumpUtil.call_num) rank_dir = f"rank{os.getpid()}" return os.path.join(save_path, step_dir, dir_name, rank_dir) diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/dump.py b/debug/accuracy_tools/api_accuracy_checker/dump/dump.py index a3a8fee7472c48e3df19f4dabd326071e5465dd5..d291995b4b61017ec7cbae672102490f2853bc6b 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/dump.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/dump.py @@ -17,7 +17,7 @@ from api_accuracy_checker.dump.api_info import ForwardAPIInfo, BackwardAPIInfo from api_accuracy_checker.dump.info_dump import write_api_info_json, initialize_output_json -from api_accuracy_checker.common.utils import print_error_log, CompareException +from api_accuracy_checker.common.utils import print_error_log, CompareException, DumpUtil, DumpConst from api_accuracy_checker.hook_module.register_hook import initialize_hook from api_accuracy_checker.common.config import msCheckerConfig @@ -43,7 +43,7 @@ def check_dataloader_status(): def start(): check_dataloader_status() if not DumpUtil.get_dump_switch(): - DumpUtil.incr_iter_num_maybe_exit() + incr_iter_num_maybe_exit() def stop(): @@ -56,32 +56,13 @@ def step(): DumpUtil.call_num += 1 -class DumpUtil(object): - dump_switch = None - call_num = 0 - - @staticmethod - def set_dump_switch(switch): - DumpUtil.dump_switch = switch - - @staticmethod - def get_dump_switch(): - return DumpUtil.dump_switch == "ON" - - @staticmethod - def incr_iter_num_maybe_exit(): - if DumpUtil.call_num in msCheckerConfig.target_iter: - set_dump_switch("ON") - elif DumpUtil.call_num > max(msCheckerConfig.target_iter): - raise Exception("Model pretest: exit after iteration {}".format(DumpUtil.call_num - 1)) - else: - set_dump_switch("OFF") - - -class DumpConst: - delimiter = '*' - forward = 'forward' - backward = 'backward' +def incr_iter_num_maybe_exit(): + if DumpUtil.call_num in msCheckerConfig.target_iter: + set_dump_switch("ON") + elif DumpUtil.call_num > max(msCheckerConfig.target_iter): + raise Exception("Model pretest: exit after iteration {}".format(DumpUtil.call_num - 1)) + else: + set_dump_switch("OFF") def pretest_info_dump(name, out_feat, module, phase): diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py b/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py index a297e7235f7b7a196112d8fa857513c4d5027f03..20c9c6f8f857a02226dea17425492811a690c720 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py @@ -1,7 +1,8 @@ # dump范围控制 import torch from torch.utils.data.dataloader import _BaseDataLoaderIter -from api_accuracy_checker.dump.dump import DumpUtil +from api_accuracy_checker.common.utils import DumpUtil +from api_accuracy_checker.dump.dump import incr_iter_num_maybe_exit from api_accuracy_checker.common.config import msCheckerConfig @@ -9,7 +10,7 @@ def iter_tracer(func): def func_wrapper(*args, **kwargs): DumpUtil.dump_switch = "OFF" result = func(*args, **kwargs) - DumpUtil.incr_iter_num_maybe_exit() + incr_iter_num_maybe_exit() DumpUtil.call_num += 1 return result return func_wrapper diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py b/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py index 70d49251f060f88477695d54f2efef776d776334..329d99c833704de635f5186d48995210021b03cf 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py @@ -7,6 +7,7 @@ from api_accuracy_checker.dump.api_info import ForwardAPIInfo, BackwardAPIInfo from api_accuracy_checker.common.utils import check_file_or_directory_path, initialize_save_path, create_directory from ptdbg_ascend.src.python.ptdbg_ascend.common.utils import check_path_before_create from api_accuracy_checker.common.config import msCheckerConfig +from api_accuracy_checker.common.utils import DumpUtil from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import FileOpen, FileCheckConst, FileChecker @@ -15,7 +16,6 @@ lock = threading.Lock() def write_api_info_json(api_info): - from api_accuracy_checker.dump.dump import DumpUtil dump_path = msCheckerConfig.dump_path dump_path = os.path.join(msCheckerConfig.dump_path, "step" + str((DumpUtil.call_num - 1) if msCheckerConfig.enable_dataloader else DumpUtil.call_num)) check_path_before_create(dump_path) diff --git a/debug/accuracy_tools/api_accuracy_checker/test/ut/dump/test_dump.py b/debug/accuracy_tools/api_accuracy_checker/test/ut/dump/test_dump.py index 655e624e809a5cceb406b9fce9df4e4f89efb4ee..bd7f1897e9cde74b82c53f6b9d816f35b2c3cf6f 100644 --- a/debug/accuracy_tools/api_accuracy_checker/test/ut/dump/test_dump.py +++ b/debug/accuracy_tools/api_accuracy_checker/test/ut/dump/test_dump.py @@ -1,5 +1,6 @@ import unittest from api_accuracy_checker.dump.dump import * +from api_accuracy_checker.common.utils import DumpUtil class TestDumpUtil(unittest.TestCase): def test_set_dump_switch(self): @@ -20,13 +21,13 @@ class TestDumpUtil(unittest.TestCase): DumpUtil.call_num = 6 with self.assertRaises(Exception): - DumpUtil.incr_iter_num_maybe_exit() + incr_iter_num_maybe_exit() DumpUtil.call_num = 4 - DumpUtil.incr_iter_num_maybe_exit() + incr_iter_num_maybe_exit() self.assertEqual(DumpUtil.dump_switch, "OFF") msCheckerConfig.enable_dataloader = False DumpUtil.call_num = 5 - DumpUtil.incr_iter_num_maybe_exit() + incr_iter_num_maybe_exit() self.assertEqual(DumpUtil.dump_switch, "ON")