From 3036416d3d4a2eb9820168e4da697684bf605635 Mon Sep 17 00:00:00 2001 From: gitee Date: Fri, 29 Dec 2023 12:03:17 +0800 Subject: [PATCH 1/5] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E5=BC=95=E7=94=A8?= =?UTF-8?q?=E4=BE=9D=E8=B5=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../api_accuracy_checker/common/base_api.py | 2 +- .../api_accuracy_checker/dump/dump.py | 29 +------------------ .../api_accuracy_checker/dump/dump_utils.py | 29 +++++++++++++++++++ .../api_accuracy_checker/dump/info_dump.py | 2 +- 4 files changed, 32 insertions(+), 30 deletions(-) create mode 100644 debug/accuracy_tools/api_accuracy_checker/dump/dump_utils.py diff --git a/debug/accuracy_tools/api_accuracy_checker/common/base_api.py b/debug/accuracy_tools/api_accuracy_checker/common/base_api.py index 4d1ebff9198..e76250d1dd3 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/base_api.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/base_api.py @@ -3,6 +3,7 @@ import torch from api_accuracy_checker.common.utils import print_error_log, write_pt, create_directory from ptdbg_ascend.src.python.ptdbg_ascend.common.utils import check_path_before_create from api_accuracy_checker.common.config import msCheckerConfig +from api_accuracy_checker.dump.dump_utils import DumpUtil class BaseAPIInfo: @@ -55,7 +56,6 @@ class BaseAPIInfo: single_arg.update({'requires_grad': arg.requires_grad}) else: api_args = self.api_name + '.' + str(self.args_num) - from api_accuracy_checker.dump.dump import DumpUtil step_dir = "step" + str(DumpUtil.call_num - 1 if msCheckerConfig.enable_dataloader else DumpUtil.call_num) rank_dir = f"rank{self.rank}" if self.is_forward: diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/dump.py b/debug/accuracy_tools/api_accuracy_checker/dump/dump.py index a3a8fee7472..893529d14b7 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/dump.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/dump.py @@ -20,6 +20,7 @@ from api_accuracy_checker.dump.info_dump import write_api_info_json, initialize_ from api_accuracy_checker.common.utils import print_error_log, CompareException from api_accuracy_checker.hook_module.register_hook import initialize_hook from api_accuracy_checker.common.config import msCheckerConfig +from api_accuracy_checker.dump.dump_utils import DumpUtil, DumpConst def set_dump_switch(switch): @@ -56,34 +57,6 @@ def step(): DumpUtil.call_num += 1 -class DumpUtil(object): - dump_switch = None - call_num = 0 - - @staticmethod - def set_dump_switch(switch): - DumpUtil.dump_switch = switch - - @staticmethod - def get_dump_switch(): - return DumpUtil.dump_switch == "ON" - - @staticmethod - def incr_iter_num_maybe_exit(): - if DumpUtil.call_num in msCheckerConfig.target_iter: - set_dump_switch("ON") - elif DumpUtil.call_num > max(msCheckerConfig.target_iter): - raise Exception("Model pretest: exit after iteration {}".format(DumpUtil.call_num - 1)) - else: - set_dump_switch("OFF") - - -class DumpConst: - delimiter = '*' - forward = 'forward' - backward = 'backward' - - def pretest_info_dump(name, out_feat, module, phase): if not DumpUtil.get_dump_switch(): return diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/dump_utils.py b/debug/accuracy_tools/api_accuracy_checker/dump/dump_utils.py new file mode 100644 index 00000000000..c89b851c77b --- /dev/null +++ b/debug/accuracy_tools/api_accuracy_checker/dump/dump_utils.py @@ -0,0 +1,29 @@ +from api_accuracy_checker.common.config import msCheckerConfig + + +class DumpUtil(object): + dump_switch = None + call_num = 0 + + @staticmethod + def set_dump_switch(switch): + DumpUtil.dump_switch = switch + + @staticmethod + def get_dump_switch(): + return DumpUtil.dump_switch == "ON" + + @staticmethod + def incr_iter_num_maybe_exit(): + if DumpUtil.call_num in msCheckerConfig.target_iter: + set_dump_switch("ON") + elif DumpUtil.call_num > max(msCheckerConfig.target_iter): + raise Exception("Model pretest: exit after iteration {}".format(DumpUtil.call_num - 1)) + else: + set_dump_switch("OFF") + + +class DumpConst: + delimiter = '*' + forward = 'forward' + backward = 'backward' \ No newline at end of file diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py b/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py index 70d49251f06..3177497cde3 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py @@ -7,6 +7,7 @@ from api_accuracy_checker.dump.api_info import ForwardAPIInfo, BackwardAPIInfo from api_accuracy_checker.common.utils import check_file_or_directory_path, initialize_save_path, create_directory from ptdbg_ascend.src.python.ptdbg_ascend.common.utils import check_path_before_create from api_accuracy_checker.common.config import msCheckerConfig +from api_accuracy_checker.dump.dump_utils import DumpUtil from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import FileOpen, FileCheckConst, FileChecker @@ -15,7 +16,6 @@ lock = threading.Lock() def write_api_info_json(api_info): - from api_accuracy_checker.dump.dump import DumpUtil dump_path = msCheckerConfig.dump_path dump_path = os.path.join(msCheckerConfig.dump_path, "step" + str((DumpUtil.call_num - 1) if msCheckerConfig.enable_dataloader else DumpUtil.call_num)) check_path_before_create(dump_path) -- Gitee From 31c4ddbaa88b9ee03173c9f3213897f01e41ad5e Mon Sep 17 00:00:00 2001 From: gitee Date: Fri, 29 Dec 2023 15:21:21 +0800 Subject: [PATCH 2/5] =?UTF-8?q?=E5=88=A0=E9=99=A4init.py?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- debug/accuracy_tools/api_accuracy_checker/dump/__init__.py | 5 ----- 1 file changed, 5 deletions(-) delete mode 100644 debug/accuracy_tools/api_accuracy_checker/dump/__init__.py diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/__init__.py b/debug/accuracy_tools/api_accuracy_checker/dump/__init__.py deleted file mode 100644 index f3e3fe66364..00000000000 --- a/debug/accuracy_tools/api_accuracy_checker/dump/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from api_accuracy_checker.dump.dump import set_dump_switch -import api_accuracy_checker.dump.dump_scope -from api_accuracy_checker.common.config import msCheckerConfig - -__all__ = ['set_dump_switch'] -- Gitee From 9fb275eb7119dd1f6478c472172f91c7baa9b7d9 Mon Sep 17 00:00:00 2001 From: gitee Date: Fri, 29 Dec 2023 17:52:18 +0800 Subject: [PATCH 3/5] =?UTF-8?q?=E4=BF=AE=E6=94=B9=5F=5Finit=5F=5F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- debug/accuracy_tools/api_accuracy_checker/dump/__init__.py | 4 ++++ 1 file changed, 4 insertions(+) create mode 100644 debug/accuracy_tools/api_accuracy_checker/dump/__init__.py diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/__init__.py b/debug/accuracy_tools/api_accuracy_checker/dump/__init__.py new file mode 100644 index 00000000000..8a25a2463dc --- /dev/null +++ b/debug/accuracy_tools/api_accuracy_checker/dump/__init__.py @@ -0,0 +1,4 @@ +import api_accuracy_checker.dump.dump_scope +from api_accuracy_checker.common.config import msCheckerConfig + +__all__ = [] -- Gitee From 422eafb7d151c0450f64e3ae84deab5f098312f5 Mon Sep 17 00:00:00 2001 From: gitee Date: Fri, 29 Dec 2023 18:01:22 +0800 Subject: [PATCH 4/5] =?UTF-8?q?=E4=BF=AE=E6=94=B9=5F=5Finit=5F=5F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py b/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py index a297e7235f7..674089887d1 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py @@ -1,7 +1,7 @@ # dump范围控制 import torch from torch.utils.data.dataloader import _BaseDataLoaderIter -from api_accuracy_checker.dump.dump import DumpUtil +from api_accuracy_checker.dump.dump_utils import DumpUtil from api_accuracy_checker.common.config import msCheckerConfig -- Gitee From 53c1516b0c405dce9a963c100b0eb00286f898d7 Mon Sep 17 00:00:00 2001 From: gitee Date: Fri, 29 Dec 2023 18:36:33 +0800 Subject: [PATCH 5/5] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E6=96=B9=E6=A1=88?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../accuracy_tools/api_accuracy_checker/dump/dump.py | 9 +++++++++ .../api_accuracy_checker/dump/dump_scope.py | 3 ++- .../api_accuracy_checker/dump/dump_utils.py | 12 ------------ 3 files changed, 11 insertions(+), 13 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/dump.py b/debug/accuracy_tools/api_accuracy_checker/dump/dump.py index 893529d14b7..4b3516fd574 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/dump.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/dump.py @@ -33,6 +33,15 @@ def set_dump_switch(switch): DumpUtil.set_dump_switch(switch) +def incr_iter_num_maybe_exit(): + if DumpUtil.call_num in msCheckerConfig.target_iter: + set_dump_switch("ON") + elif DumpUtil.call_num > max(msCheckerConfig.target_iter): + raise Exception("Model pretest: exit after iteration {}".format(DumpUtil.call_num - 1)) + else: + set_dump_switch("OFF") + + def check_dataloader_status(): if msCheckerConfig.enable_dataloader: error_info = ("If you want to use this function, set enable_dataloader " diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py b/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py index 674089887d1..e2edc15bc6d 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py @@ -2,6 +2,7 @@ import torch from torch.utils.data.dataloader import _BaseDataLoaderIter from api_accuracy_checker.dump.dump_utils import DumpUtil +from api_accuracy_checker.dump.dump import incr_iter_num_maybe_exit from api_accuracy_checker.common.config import msCheckerConfig @@ -9,7 +10,7 @@ def iter_tracer(func): def func_wrapper(*args, **kwargs): DumpUtil.dump_switch = "OFF" result = func(*args, **kwargs) - DumpUtil.incr_iter_num_maybe_exit() + incr_iter_num_maybe_exit() DumpUtil.call_num += 1 return result return func_wrapper diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/dump_utils.py b/debug/accuracy_tools/api_accuracy_checker/dump/dump_utils.py index c89b851c77b..1ab8bdeace4 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/dump_utils.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/dump_utils.py @@ -1,6 +1,3 @@ -from api_accuracy_checker.common.config import msCheckerConfig - - class DumpUtil(object): dump_switch = None call_num = 0 @@ -13,15 +10,6 @@ class DumpUtil(object): def get_dump_switch(): return DumpUtil.dump_switch == "ON" - @staticmethod - def incr_iter_num_maybe_exit(): - if DumpUtil.call_num in msCheckerConfig.target_iter: - set_dump_switch("ON") - elif DumpUtil.call_num > max(msCheckerConfig.target_iter): - raise Exception("Model pretest: exit after iteration {}".format(DumpUtil.call_num - 1)) - else: - set_dump_switch("OFF") - class DumpConst: delimiter = '*' -- Gitee