From af72e5294cd724443c268d1face884567aaba397 Mon Sep 17 00:00:00 2001 From: s30048155 Date: Mon, 11 Dec 2023 16:38:50 +0800 Subject: [PATCH 01/10] add dataloader switch --- .../api_accuracy_checker/common/base_api.py | 4 ++-- .../api_accuracy_checker/common/config.py | 10 +++++----- debug/accuracy_tools/api_accuracy_checker/config.yaml | 1 + debug/accuracy_tools/api_accuracy_checker/dump/dump.py | 4 ++++ .../api_accuracy_checker/dump/dump_scope.py | 4 +++- 5 files changed, 15 insertions(+), 8 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/common/base_api.py b/debug/accuracy_tools/api_accuracy_checker/common/base_api.py index 5bcc0e78af..3933a1c081 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/base_api.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/base_api.py @@ -57,12 +57,12 @@ class BaseAPIInfo: api_args = self.api_name + '.' + str(self.args_num) from api_accuracy_checker.dump.dump import DumpUtil if self.is_forward: - forward_real_data_path = os.path.join(self.save_path, "step" + str(DumpUtil.call_num - 1), self.forward_path, "rank" + str(self.rank)) + forward_real_data_path = os.path.join(self.save_path, "step" + str(DumpUtil.call_num), self.forward_path, "rank" + str(self.rank)) check_path_before_create(forward_real_data_path) create_directory(forward_real_data_path) file_path = os.path.join(forward_real_data_path, f'{api_args}.pt') else: - backward_real_data_path = os.path.join(self.save_path, "step" + str(DumpUtil.call_num - 1), self.backward_path, "rank" + str(self.rank)) + backward_real_data_path = os.path.join(self.save_path, "step" + str(DumpUtil.call_num), self.backward_path, "rank" + str(self.rank)) check_path_before_create(backward_real_data_path) create_directory(backward_real_data_path) file_path = os.path.join(backward_real_data_path, f'{api_args}.pt') diff --git a/debug/accuracy_tools/api_accuracy_checker/common/config.py b/debug/accuracy_tools/api_accuracy_checker/common/config.py index f9b882f47b..2a746d5059 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/config.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/config.py @@ -23,7 +23,8 @@ class Config: 'error_data_path': str, 'target_iter': list, 'precision': int, - 'white_list': list + 'white_list': list, + 'enable_dataloader': bool } if not isinstance(value, validators.get(key)): raise ValueError(f"{key} must be {validators[key].__name__} type") @@ -54,14 +55,13 @@ class Config: def __str__(self): return '\n'.join(f"{key}={value}" for key, value in self.config.items()) - def update_config(self, dump_path=None, real_data=False, target_iter=None, white_list=None): - if target_iter is None: - target_iter = self.config.get('target_iter',[1]) + def update_config(self, dump_path=None, real_data=False, target_iter=None, white_list=None, enable_dataloader=None): args = { "dump_path": dump_path if dump_path else self.config.get("dump_path", './'), "real_data": real_data, "target_iter": target_iter if target_iter else self.config.get("target_iter", [1]), - "white_list": white_list if white_list else self.config.get("white_list", []) + "white_list": white_list if white_list else self.config.get("white_list", []), + "enable_dataloader": enable_dataloader if enable_dataloader else self.config.get("enable_dataloader", True) } for key, value in args.items(): if key in self.config: diff --git a/debug/accuracy_tools/api_accuracy_checker/config.yaml b/debug/accuracy_tools/api_accuracy_checker/config.yaml index 0bd145893e..ece957347a 100644 --- a/debug/accuracy_tools/api_accuracy_checker/config.yaml +++ b/debug/accuracy_tools/api_accuracy_checker/config.yaml @@ -6,4 +6,5 @@ error_data_path: './' target_iter: [1] precision: 14 white_list: [] +enable_dataloader: True \ No newline at end of file diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/dump.py b/debug/accuracy_tools/api_accuracy_checker/dump/dump.py index 931dcae9f2..3384351c68 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/dump.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/dump.py @@ -53,6 +53,10 @@ class DumpUtil(object): else: set_dump_switch("OFF") DumpUtil.call_num += 1 + + @staticmethod + def step(): + DumpUtil.call_num += 1 class DumpConst: diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py b/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py index 85f555ed75..df85fe30d0 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py @@ -13,4 +13,6 @@ def iter_tracer(func): return result return func_wrapper -_BaseDataLoaderIter.__next__ = iter_tracer(torch.utils.data.dataloader._BaseDataLoaderIter.__next__) \ No newline at end of file +if msCheckerConfig.enable_dataloader: + DumpUtil.iter_num -= 1 + _BaseDataLoaderIter.__next__ = iter_tracer(torch.utils.data.dataloader._BaseDataLoaderIter.__next__) \ No newline at end of file -- Gitee From 909877d5bf90e750c55fa6feb9b032357d3d87c3 Mon Sep 17 00:00:00 2001 From: s30048155 Date: Mon, 11 Dec 2023 18:10:31 +0800 Subject: [PATCH 02/10] bugfix --- debug/accuracy_tools/api_accuracy_checker/dump/dump.py | 2 +- debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py | 2 +- debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/dump.py b/debug/accuracy_tools/api_accuracy_checker/dump/dump.py index 3384351c68..04315e4dfa 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/dump.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/dump.py @@ -49,7 +49,7 @@ class DumpUtil(object): if DumpUtil.call_num in msCheckerConfig.target_iter: set_dump_switch("ON") elif DumpUtil.call_num > max(msCheckerConfig.target_iter): - raise Exception("Model pretest: exit after iteration {}".format(DumpUtil.call_num - 1)) + raise Exception("Model pretest: exit after iteration {}".format(DumpUtil.call_num)) else: set_dump_switch("OFF") DumpUtil.call_num += 1 diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py b/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py index df85fe30d0..144a842744 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py @@ -14,5 +14,5 @@ def iter_tracer(func): return func_wrapper if msCheckerConfig.enable_dataloader: - DumpUtil.iter_num -= 1 + DumpUtil.call_num -= 1 _BaseDataLoaderIter.__next__ = iter_tracer(torch.utils.data.dataloader._BaseDataLoaderIter.__next__) \ No newline at end of file diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py b/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py index 69de65912f..3b9ceb20c5 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py @@ -17,7 +17,7 @@ lock = threading.Lock() def write_api_info_json(api_info): from api_accuracy_checker.dump.dump import DumpUtil dump_path = msCheckerConfig.dump_path - dump_path = os.path.join(msCheckerConfig.dump_path, "step" + str(DumpUtil.call_num - 1)) + dump_path = os.path.join(msCheckerConfig.dump_path, "step" + str(DumpUtil.call_num)) check_path_before_create(dump_path) create_directory(dump_path) rank = api_info.rank -- Gitee From e945b4293fcc3100cb649a92be95e954b80e60f7 Mon Sep 17 00:00:00 2001 From: s30048155 Date: Mon, 11 Dec 2023 18:54:42 +0800 Subject: [PATCH 03/10] bugfix --- .../api_accuracy_checker/common/base_api.py | 10 ++++++---- .../api_accuracy_checker/dump/dump.py | 17 ++++++++++++----- .../api_accuracy_checker/dump/dump_scope.py | 1 + .../api_accuracy_checker/dump/info_dump.py | 5 +++-- .../api_accuracy_checker/run_ut/ut_api_info.py | 5 +++-- 5 files changed, 25 insertions(+), 13 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/common/base_api.py b/debug/accuracy_tools/api_accuracy_checker/common/base_api.py index 3933a1c081..4ad4ca1ceb 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/base_api.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/base_api.py @@ -58,13 +58,15 @@ class BaseAPIInfo: from api_accuracy_checker.dump.dump import DumpUtil if self.is_forward: forward_real_data_path = os.path.join(self.save_path, "step" + str(DumpUtil.call_num), self.forward_path, "rank" + str(self.rank)) - check_path_before_create(forward_real_data_path) - create_directory(forward_real_data_path) + if not os.path.exists(forward_real_data_path): + check_path_before_create(forward_real_data_path) + create_directory(forward_real_data_path) file_path = os.path.join(forward_real_data_path, f'{api_args}.pt') else: backward_real_data_path = os.path.join(self.save_path, "step" + str(DumpUtil.call_num), self.backward_path, "rank" + str(self.rank)) - check_path_before_create(backward_real_data_path) - create_directory(backward_real_data_path) + if not os.path.exists(backward_real_data_path): + check_path_before_create(backward_real_data_path) + create_directory(backward_real_data_path) file_path = os.path.join(backward_real_data_path, f'{api_args}.pt') self.args_num += 1 pt_path = write_pt(file_path, arg.contiguous().cpu().detach()) diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/dump.py b/debug/accuracy_tools/api_accuracy_checker/dump/dump.py index 04315e4dfa..a1911f2492 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/dump.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/dump.py @@ -32,6 +32,18 @@ def set_dump_switch(switch): DumpUtil.set_dump_switch(switch) +def start(): + DumpUtil.incr_iter_num_maybe_exit() + + +def stop(): + DumpUtil.set_dump_switch("OFF") + + +def step(): + DumpUtil.call_num += 1 + + class DumpUtil(object): dump_switch = None call_num = 0 @@ -52,11 +64,6 @@ class DumpUtil(object): raise Exception("Model pretest: exit after iteration {}".format(DumpUtil.call_num)) else: set_dump_switch("OFF") - DumpUtil.call_num += 1 - - @staticmethod - def step(): - DumpUtil.call_num += 1 class DumpConst: diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py b/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py index 144a842744..550984b40e 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py @@ -10,6 +10,7 @@ def iter_tracer(func): DumpUtil.dump_switch = "OFF" result = func(*args, **kwargs) DumpUtil.incr_iter_num_maybe_exit() + DumpUtil.call_num += 1 return result return func_wrapper diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py b/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py index 3b9ceb20c5..0d9a6224ff 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py @@ -18,8 +18,9 @@ def write_api_info_json(api_info): from api_accuracy_checker.dump.dump import DumpUtil dump_path = msCheckerConfig.dump_path dump_path = os.path.join(msCheckerConfig.dump_path, "step" + str(DumpUtil.call_num)) - check_path_before_create(dump_path) - create_directory(dump_path) + if not os.path.exists(dump_path): + check_path_before_create(dump_path) + create_directory(dump_path) rank = api_info.rank if isinstance(api_info, ForwardAPIInfo): file_path = os.path.join(dump_path, f'forward_info_{rank}.json') diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/ut_api_info.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/ut_api_info.py index 7d345ac0ab..cd8986cda6 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/ut_api_info.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/ut_api_info.py @@ -15,8 +15,9 @@ class UtAPIInfo(BaseAPIInfo): single_arg = {} api_args = self.api_name + '.' + str(self.args_num) ut_error_data_path = os.path.join(self.save_path, self.ut_error_data_dir) - check_path_before_create(ut_error_data_path) - create_directory(ut_error_data_path) + if not os.path.exists(ut_error_data_path): + check_path_before_create(ut_error_data_path) + create_directory(ut_error_data_path) file_path = os.path.join(ut_error_data_path, f'{api_args}.pt') self.args_num += 1 pt_path = write_pt(file_path, arg.contiguous().cpu().detach()) -- Gitee From 4ca84f8139f8c877826e7ca73d7035f8781774d2 Mon Sep 17 00:00:00 2001 From: s30048155 Date: Mon, 11 Dec 2023 19:17:24 +0800 Subject: [PATCH 04/10] bugfix --- .../api_accuracy_checker/common/base_api.py | 10 ++++------ .../api_accuracy_checker/common/config.py | 5 ++--- .../api_accuracy_checker/dump/info_dump.py | 7 +++---- .../api_accuracy_checker/run_ut/ut_api_info.py | 5 ++--- 4 files changed, 11 insertions(+), 16 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/common/base_api.py b/debug/accuracy_tools/api_accuracy_checker/common/base_api.py index 4ad4ca1ceb..3933a1c081 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/base_api.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/base_api.py @@ -58,15 +58,13 @@ class BaseAPIInfo: from api_accuracy_checker.dump.dump import DumpUtil if self.is_forward: forward_real_data_path = os.path.join(self.save_path, "step" + str(DumpUtil.call_num), self.forward_path, "rank" + str(self.rank)) - if not os.path.exists(forward_real_data_path): - check_path_before_create(forward_real_data_path) - create_directory(forward_real_data_path) + check_path_before_create(forward_real_data_path) + create_directory(forward_real_data_path) file_path = os.path.join(forward_real_data_path, f'{api_args}.pt') else: backward_real_data_path = os.path.join(self.save_path, "step" + str(DumpUtil.call_num), self.backward_path, "rank" + str(self.rank)) - if not os.path.exists(backward_real_data_path): - check_path_before_create(backward_real_data_path) - create_directory(backward_real_data_path) + check_path_before_create(backward_real_data_path) + create_directory(backward_real_data_path) file_path = os.path.join(backward_real_data_path, f'{api_args}.pt') self.args_num += 1 pt_path = write_pt(file_path, arg.contiguous().cpu().detach()) diff --git a/debug/accuracy_tools/api_accuracy_checker/common/config.py b/debug/accuracy_tools/api_accuracy_checker/common/config.py index 2a746d5059..1d9eda4105 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/config.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/config.py @@ -55,13 +55,12 @@ class Config: def __str__(self): return '\n'.join(f"{key}={value}" for key, value in self.config.items()) - def update_config(self, dump_path=None, real_data=False, target_iter=None, white_list=None, enable_dataloader=None): + def update_config(self, dump_path=None, real_data=False, target_iter=None, white_list=None): args = { "dump_path": dump_path if dump_path else self.config.get("dump_path", './'), "real_data": real_data, "target_iter": target_iter if target_iter else self.config.get("target_iter", [1]), - "white_list": white_list if white_list else self.config.get("white_list", []), - "enable_dataloader": enable_dataloader if enable_dataloader else self.config.get("enable_dataloader", True) + "white_list": white_list if white_list else self.config.get("white_list", []) } for key, value in args.items(): if key in self.config: diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py b/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py index 0d9a6224ff..00cc8e5d0c 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py @@ -17,10 +17,9 @@ lock = threading.Lock() def write_api_info_json(api_info): from api_accuracy_checker.dump.dump import DumpUtil dump_path = msCheckerConfig.dump_path - dump_path = os.path.join(msCheckerConfig.dump_path, "step" + str(DumpUtil.call_num)) - if not os.path.exists(dump_path): - check_path_before_create(dump_path) - create_directory(dump_path) + dump_path = os.path.join(msCheckerConfig.dump_path, "step" + str(DumpUtil.call_num)) + check_path_before_create(dump_path) + create_directory(dump_path) rank = api_info.rank if isinstance(api_info, ForwardAPIInfo): file_path = os.path.join(dump_path, f'forward_info_{rank}.json') diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/ut_api_info.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/ut_api_info.py index cd8986cda6..7d345ac0ab 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/ut_api_info.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/ut_api_info.py @@ -15,9 +15,8 @@ class UtAPIInfo(BaseAPIInfo): single_arg = {} api_args = self.api_name + '.' + str(self.args_num) ut_error_data_path = os.path.join(self.save_path, self.ut_error_data_dir) - if not os.path.exists(ut_error_data_path): - check_path_before_create(ut_error_data_path) - create_directory(ut_error_data_path) + check_path_before_create(ut_error_data_path) + create_directory(ut_error_data_path) file_path = os.path.join(ut_error_data_path, f'{api_args}.pt') self.args_num += 1 pt_path = write_pt(file_path, arg.contiguous().cpu().detach()) -- Gitee From 3549443721babe4ad233156aae96b5a74b664884 Mon Sep 17 00:00:00 2001 From: s30048155 Date: Mon, 11 Dec 2023 19:32:13 +0800 Subject: [PATCH 05/10] bugfix --- debug/accuracy_tools/api_accuracy_checker/common/base_api.py | 5 +++-- debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/common/base_api.py b/debug/accuracy_tools/api_accuracy_checker/common/base_api.py index 3933a1c081..55e06a7be1 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/base_api.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/base_api.py @@ -2,6 +2,7 @@ import os import torch from api_accuracy_checker.common.utils import print_error_log, write_pt, create_directory from ptdbg_ascend.src.python.ptdbg_ascend.common.utils import check_path_before_create +from api_accuracy_checker.common.config import msCheckerConfig class BaseAPIInfo: @@ -57,12 +58,12 @@ class BaseAPIInfo: api_args = self.api_name + '.' + str(self.args_num) from api_accuracy_checker.dump.dump import DumpUtil if self.is_forward: - forward_real_data_path = os.path.join(self.save_path, "step" + str(DumpUtil.call_num), self.forward_path, "rank" + str(self.rank)) + forward_real_data_path = os.path.join(self.save_path, "step" + str((DumpUtil.call_num - 1) if msCheckerConfig.enable_dataloader else DumpUtil.call_num), self.forward_path, "rank" + str(self.rank)) check_path_before_create(forward_real_data_path) create_directory(forward_real_data_path) file_path = os.path.join(forward_real_data_path, f'{api_args}.pt') else: - backward_real_data_path = os.path.join(self.save_path, "step" + str(DumpUtil.call_num), self.backward_path, "rank" + str(self.rank)) + backward_real_data_path = os.path.join(self.save_path, "step" + str((DumpUtil.call_num - 1) if msCheckerConfig.enable_dataloader else DumpUtil.call_num), self.backward_path, "rank" + str(self.rank)) check_path_before_create(backward_real_data_path) create_directory(backward_real_data_path) file_path = os.path.join(backward_real_data_path, f'{api_args}.pt') diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py b/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py index 00cc8e5d0c..70d49251f0 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py @@ -17,7 +17,7 @@ lock = threading.Lock() def write_api_info_json(api_info): from api_accuracy_checker.dump.dump import DumpUtil dump_path = msCheckerConfig.dump_path - dump_path = os.path.join(msCheckerConfig.dump_path, "step" + str(DumpUtil.call_num)) + dump_path = os.path.join(msCheckerConfig.dump_path, "step" + str((DumpUtil.call_num - 1) if msCheckerConfig.enable_dataloader else DumpUtil.call_num)) check_path_before_create(dump_path) create_directory(dump_path) rank = api_info.rank -- Gitee From 798e87654a08dd4aec44c1f8d77b50397f328617 Mon Sep 17 00:00:00 2001 From: s30048155 Date: Tue, 12 Dec 2023 11:15:24 +0800 Subject: [PATCH 06/10] thread-safe --- .../api_accuracy_checker/common/utils.py | 15 +++++++-------- .../test/ut/dump/test_info_dump.py | 4 ++-- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/common/utils.py b/debug/accuracy_tools/api_accuracy_checker/common/utils.py index f8648eb20c..1d34db2515 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/utils.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/utils.py @@ -378,19 +378,18 @@ def modify_dump_path(dump_path, mode): def create_directory(dir_path): """ Function Description: - creating a directory with specified permissions + creating a directory with specified permissions in a thread-safe manner Parameter: dir_path: directory path Exception Description: when invalid data throw exception """ - if not os.path.exists(dir_path): - try: - os.makedirs(dir_path, mode=FileCheckConst.DATA_DIR_AUTHORITY) - except OSError as ex: - print_error_log( - 'Failed to create {}.Please check the path permission or disk space .{}'.format(dir_path, str(ex))) - raise CompareException(CompareException.INVALID_PATH_ERROR) from ex + try: + os.makedirs(dir_path, mode=FileCheckConst.DATA_DIR_AUTHORITY, exist_ok=True) + except OSError as ex: + print_error_log( + 'Failed to create {}. Please check the path permission or disk space. {}'.format(dir_path, str(ex))) + raise CompareException(CompareException.INVALID_PATH_ERROR) from ex def execute_command(cmd): diff --git a/debug/accuracy_tools/api_accuracy_checker/test/ut/dump/test_info_dump.py b/debug/accuracy_tools/api_accuracy_checker/test/ut/dump/test_info_dump.py index 2574716c40..0dd0898c16 100644 --- a/debug/accuracy_tools/api_accuracy_checker/test/ut/dump/test_info_dump.py +++ b/debug/accuracy_tools/api_accuracy_checker/test/ut/dump/test_info_dump.py @@ -13,7 +13,7 @@ class TestInfoDump(unittest.TestCase): with patch('api_accuracy_checker.dump.info_dump.write_json') as mock_write_json: write_api_info_json(api_info) rank = os.getpid() - mock_write_json.assert_called_with(f'./step1/backward_info_{rank}.json', api_info.grad_info_struct) + mock_write_json.assert_called_with(f'./step2/backward_info_{rank}.json', api_info.grad_info_struct) def test_write_api_info_json_invalid_type(self): api_info = APIInfo("test_api", True, True, "save_path") @@ -22,7 +22,7 @@ class TestInfoDump(unittest.TestCase): def tearDown(self): rank = os.getpid() - files = [f'./step1/backward_info_{rank}.json'] + files = [f'./step2/backward_info_{rank}.json'] for file in files: if os.path.exists(file): os.remove(file) -- Gitee From a3ffb34c908ab36b826ee41af26a8977a6e5ee5f Mon Sep 17 00:00:00 2001 From: s30048155 Date: Tue, 12 Dec 2023 15:27:49 +0800 Subject: [PATCH 07/10] update --- debug/accuracy_tools/api_accuracy_checker/common/base_api.py | 4 ++-- debug/accuracy_tools/api_accuracy_checker/dump/dump.py | 2 +- debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py | 1 - debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py | 2 +- 4 files changed, 4 insertions(+), 5 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/common/base_api.py b/debug/accuracy_tools/api_accuracy_checker/common/base_api.py index 55e06a7be1..75ccd4a0e7 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/base_api.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/base_api.py @@ -58,12 +58,12 @@ class BaseAPIInfo: api_args = self.api_name + '.' + str(self.args_num) from api_accuracy_checker.dump.dump import DumpUtil if self.is_forward: - forward_real_data_path = os.path.join(self.save_path, "step" + str((DumpUtil.call_num - 1) if msCheckerConfig.enable_dataloader else DumpUtil.call_num), self.forward_path, "rank" + str(self.rank)) + forward_real_data_path = os.path.join(self.save_path, "step" + str(DumpUtil.call_num - 1), self.forward_path, "rank" + str(self.rank)) check_path_before_create(forward_real_data_path) create_directory(forward_real_data_path) file_path = os.path.join(forward_real_data_path, f'{api_args}.pt') else: - backward_real_data_path = os.path.join(self.save_path, "step" + str((DumpUtil.call_num - 1) if msCheckerConfig.enable_dataloader else DumpUtil.call_num), self.backward_path, "rank" + str(self.rank)) + backward_real_data_path = os.path.join(self.save_path, "step" + str(DumpUtil.call_num - 1), self.backward_path, "rank" + str(self.rank)) check_path_before_create(backward_real_data_path) create_directory(backward_real_data_path) file_path = os.path.join(backward_real_data_path, f'{api_args}.pt') diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/dump.py b/debug/accuracy_tools/api_accuracy_checker/dump/dump.py index a1911f2492..b3159383b5 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/dump.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/dump.py @@ -61,7 +61,7 @@ class DumpUtil(object): if DumpUtil.call_num in msCheckerConfig.target_iter: set_dump_switch("ON") elif DumpUtil.call_num > max(msCheckerConfig.target_iter): - raise Exception("Model pretest: exit after iteration {}".format(DumpUtil.call_num)) + raise Exception("Model pretest: exit after iteration {}".format(DumpUtil.call_num - 1)) else: set_dump_switch("OFF") diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py b/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py index 550984b40e..a297e7235f 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py @@ -15,5 +15,4 @@ def iter_tracer(func): return func_wrapper if msCheckerConfig.enable_dataloader: - DumpUtil.call_num -= 1 _BaseDataLoaderIter.__next__ = iter_tracer(torch.utils.data.dataloader._BaseDataLoaderIter.__next__) \ No newline at end of file diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py b/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py index 70d49251f0..a288c5b69b 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py @@ -17,7 +17,7 @@ lock = threading.Lock() def write_api_info_json(api_info): from api_accuracy_checker.dump.dump import DumpUtil dump_path = msCheckerConfig.dump_path - dump_path = os.path.join(msCheckerConfig.dump_path, "step" + str((DumpUtil.call_num - 1) if msCheckerConfig.enable_dataloader else DumpUtil.call_num)) + dump_path = os.path.join(msCheckerConfig.dump_path, "step" + str(DumpUtil.call_num - 1)) check_path_before_create(dump_path) create_directory(dump_path) rank = api_info.rank -- Gitee From 667851647861367a643801b825d3ea3624bf9190 Mon Sep 17 00:00:00 2001 From: sunyiming Date: Tue, 12 Dec 2023 07:54:05 +0000 Subject: [PATCH 08/10] Revert "update" This reverts commit a3ffb34c908ab36b826ee41af26a8977a6e5ee5f. --- debug/accuracy_tools/api_accuracy_checker/common/base_api.py | 4 ++-- debug/accuracy_tools/api_accuracy_checker/dump/dump.py | 2 +- debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py | 1 + debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py | 2 +- 4 files changed, 5 insertions(+), 4 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/common/base_api.py b/debug/accuracy_tools/api_accuracy_checker/common/base_api.py index 75ccd4a0e7..55e06a7be1 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/base_api.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/base_api.py @@ -58,12 +58,12 @@ class BaseAPIInfo: api_args = self.api_name + '.' + str(self.args_num) from api_accuracy_checker.dump.dump import DumpUtil if self.is_forward: - forward_real_data_path = os.path.join(self.save_path, "step" + str(DumpUtil.call_num - 1), self.forward_path, "rank" + str(self.rank)) + forward_real_data_path = os.path.join(self.save_path, "step" + str((DumpUtil.call_num - 1) if msCheckerConfig.enable_dataloader else DumpUtil.call_num), self.forward_path, "rank" + str(self.rank)) check_path_before_create(forward_real_data_path) create_directory(forward_real_data_path) file_path = os.path.join(forward_real_data_path, f'{api_args}.pt') else: - backward_real_data_path = os.path.join(self.save_path, "step" + str(DumpUtil.call_num - 1), self.backward_path, "rank" + str(self.rank)) + backward_real_data_path = os.path.join(self.save_path, "step" + str((DumpUtil.call_num - 1) if msCheckerConfig.enable_dataloader else DumpUtil.call_num), self.backward_path, "rank" + str(self.rank)) check_path_before_create(backward_real_data_path) create_directory(backward_real_data_path) file_path = os.path.join(backward_real_data_path, f'{api_args}.pt') diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/dump.py b/debug/accuracy_tools/api_accuracy_checker/dump/dump.py index b3159383b5..a1911f2492 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/dump.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/dump.py @@ -61,7 +61,7 @@ class DumpUtil(object): if DumpUtil.call_num in msCheckerConfig.target_iter: set_dump_switch("ON") elif DumpUtil.call_num > max(msCheckerConfig.target_iter): - raise Exception("Model pretest: exit after iteration {}".format(DumpUtil.call_num - 1)) + raise Exception("Model pretest: exit after iteration {}".format(DumpUtil.call_num)) else: set_dump_switch("OFF") diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py b/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py index a297e7235f..550984b40e 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py @@ -15,4 +15,5 @@ def iter_tracer(func): return func_wrapper if msCheckerConfig.enable_dataloader: + DumpUtil.call_num -= 1 _BaseDataLoaderIter.__next__ = iter_tracer(torch.utils.data.dataloader._BaseDataLoaderIter.__next__) \ No newline at end of file diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py b/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py index a288c5b69b..70d49251f0 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py @@ -17,7 +17,7 @@ lock = threading.Lock() def write_api_info_json(api_info): from api_accuracy_checker.dump.dump import DumpUtil dump_path = msCheckerConfig.dump_path - dump_path = os.path.join(msCheckerConfig.dump_path, "step" + str(DumpUtil.call_num - 1)) + dump_path = os.path.join(msCheckerConfig.dump_path, "step" + str((DumpUtil.call_num - 1) if msCheckerConfig.enable_dataloader else DumpUtil.call_num)) check_path_before_create(dump_path) create_directory(dump_path) rank = api_info.rank -- Gitee From 61d1396d29e201942ea71b4d365308f05442b674 Mon Sep 17 00:00:00 2001 From: s30048155 Date: Tue, 12 Dec 2023 21:34:29 +0800 Subject: [PATCH 09/10] update --- debug/accuracy_tools/api_accuracy_checker/dump/dump.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/dump.py b/debug/accuracy_tools/api_accuracy_checker/dump/dump.py index a1911f2492..967d789b97 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/dump.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/dump.py @@ -33,7 +33,8 @@ def set_dump_switch(switch): def start(): - DumpUtil.incr_iter_num_maybe_exit() + if not DumpUtil.get_dump_switch() and not msCheckerConfig.enable_dataloader: + DumpUtil.incr_iter_num_maybe_exit() def stop(): @@ -41,7 +42,10 @@ def stop(): def step(): - DumpUtil.call_num += 1 + if not msCheckerConfig.enable_dataloader: + DumpUtil.call_num += 1 + else: + print_error_log("The step() is not supported in dataloader mode.") class DumpUtil(object): @@ -61,7 +65,7 @@ class DumpUtil(object): if DumpUtil.call_num in msCheckerConfig.target_iter: set_dump_switch("ON") elif DumpUtil.call_num > max(msCheckerConfig.target_iter): - raise Exception("Model pretest: exit after iteration {}".format(DumpUtil.call_num)) + raise Exception("Model pretest: exit after iteration {}".format(DumpUtil.call_num - 1)) else: set_dump_switch("OFF") -- Gitee From 7f158dc49e565abd1d20f3573d50aeeccc0ddbac Mon Sep 17 00:00:00 2001 From: s30048155 Date: Tue, 12 Dec 2023 22:25:36 +0800 Subject: [PATCH 10/10] update --- debug/accuracy_tools/api_accuracy_checker/dump/dump.py | 2 -- debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py | 1 - 2 files changed, 3 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/dump.py b/debug/accuracy_tools/api_accuracy_checker/dump/dump.py index 967d789b97..7cf3509190 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/dump.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/dump.py @@ -44,8 +44,6 @@ def stop(): def step(): if not msCheckerConfig.enable_dataloader: DumpUtil.call_num += 1 - else: - print_error_log("The step() is not supported in dataloader mode.") class DumpUtil(object): diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py b/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py index 550984b40e..a297e7235f 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py @@ -15,5 +15,4 @@ def iter_tracer(func): return func_wrapper if msCheckerConfig.enable_dataloader: - DumpUtil.call_num -= 1 _BaseDataLoaderIter.__next__ = iter_tracer(torch.utils.data.dataloader._BaseDataLoaderIter.__next__) \ No newline at end of file -- Gitee