diff --git a/debug/accuracy_tools/api_accuracy_checker/common/base_api.py b/debug/accuracy_tools/api_accuracy_checker/common/base_api.py index 5bcc0e78af7ad24ee7e0f73f7e3e5ff86648aa26..55e06a7be13282d9c1acd178f12fd874aa87fec4 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/base_api.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/base_api.py @@ -2,6 +2,7 @@ import os import torch from api_accuracy_checker.common.utils import print_error_log, write_pt, create_directory from ptdbg_ascend.src.python.ptdbg_ascend.common.utils import check_path_before_create +from api_accuracy_checker.common.config import msCheckerConfig class BaseAPIInfo: @@ -57,12 +58,12 @@ class BaseAPIInfo: api_args = self.api_name + '.' + str(self.args_num) from api_accuracy_checker.dump.dump import DumpUtil if self.is_forward: - forward_real_data_path = os.path.join(self.save_path, "step" + str(DumpUtil.call_num - 1), self.forward_path, "rank" + str(self.rank)) + forward_real_data_path = os.path.join(self.save_path, "step" + str((DumpUtil.call_num - 1) if msCheckerConfig.enable_dataloader else DumpUtil.call_num), self.forward_path, "rank" + str(self.rank)) check_path_before_create(forward_real_data_path) create_directory(forward_real_data_path) file_path = os.path.join(forward_real_data_path, f'{api_args}.pt') else: - backward_real_data_path = os.path.join(self.save_path, "step" + str(DumpUtil.call_num - 1), self.backward_path, "rank" + str(self.rank)) + backward_real_data_path = os.path.join(self.save_path, "step" + str((DumpUtil.call_num - 1) if msCheckerConfig.enable_dataloader else DumpUtil.call_num), self.backward_path, "rank" + str(self.rank)) check_path_before_create(backward_real_data_path) create_directory(backward_real_data_path) file_path = os.path.join(backward_real_data_path, f'{api_args}.pt') diff --git a/debug/accuracy_tools/api_accuracy_checker/common/config.py b/debug/accuracy_tools/api_accuracy_checker/common/config.py index f9b882f47b107a474b9f888e76fe94defd2b26fe..1d9eda41052af523618183468bf61cdae334c0a5 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/config.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/config.py @@ -23,7 +23,8 @@ class Config: 'error_data_path': str, 'target_iter': list, 'precision': int, - 'white_list': list + 'white_list': list, + 'enable_dataloader': bool } if not isinstance(value, validators.get(key)): raise ValueError(f"{key} must be {validators[key].__name__} type") @@ -55,8 +56,6 @@ class Config: return '\n'.join(f"{key}={value}" for key, value in self.config.items()) def update_config(self, dump_path=None, real_data=False, target_iter=None, white_list=None): - if target_iter is None: - target_iter = self.config.get('target_iter',[1]) args = { "dump_path": dump_path if dump_path else self.config.get("dump_path", './'), "real_data": real_data, diff --git a/debug/accuracy_tools/api_accuracy_checker/common/utils.py b/debug/accuracy_tools/api_accuracy_checker/common/utils.py index f8648eb20ccd7572c39a752db287e8b63a0e096c..1d34db2515e2e4cf2b18dedaf319789219dbfac4 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/utils.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/utils.py @@ -378,19 +378,18 @@ def modify_dump_path(dump_path, mode): def create_directory(dir_path): """ Function Description: - creating a directory with specified permissions + creating a directory with specified permissions in a thread-safe manner Parameter: dir_path: directory path Exception Description: when invalid data throw exception """ - if not os.path.exists(dir_path): - try: - os.makedirs(dir_path, mode=FileCheckConst.DATA_DIR_AUTHORITY) - except OSError as ex: - print_error_log( - 'Failed to create {}.Please check the path permission or disk space .{}'.format(dir_path, str(ex))) - raise CompareException(CompareException.INVALID_PATH_ERROR) from ex + try: + os.makedirs(dir_path, mode=FileCheckConst.DATA_DIR_AUTHORITY, exist_ok=True) + except OSError as ex: + print_error_log( + 'Failed to create {}. Please check the path permission or disk space. {}'.format(dir_path, str(ex))) + raise CompareException(CompareException.INVALID_PATH_ERROR) from ex def execute_command(cmd): diff --git a/debug/accuracy_tools/api_accuracy_checker/config.yaml b/debug/accuracy_tools/api_accuracy_checker/config.yaml index 0bd145893e83c00c5aff120b82e537d28f4664eb..ece957347ae8073973e67237ff75572271e55077 100644 --- a/debug/accuracy_tools/api_accuracy_checker/config.yaml +++ b/debug/accuracy_tools/api_accuracy_checker/config.yaml @@ -6,4 +6,5 @@ error_data_path: './' target_iter: [1] precision: 14 white_list: [] +enable_dataloader: True \ No newline at end of file diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/dump.py b/debug/accuracy_tools/api_accuracy_checker/dump/dump.py index 931dcae9f246d1ea264a915851ef0d793eb87d83..7cf35091908f453286109562d906b5c8b55bc777 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/dump.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/dump.py @@ -32,6 +32,20 @@ def set_dump_switch(switch): DumpUtil.set_dump_switch(switch) +def start(): + if not DumpUtil.get_dump_switch() and not msCheckerConfig.enable_dataloader: + DumpUtil.incr_iter_num_maybe_exit() + + +def stop(): + DumpUtil.set_dump_switch("OFF") + + +def step(): + if not msCheckerConfig.enable_dataloader: + DumpUtil.call_num += 1 + + class DumpUtil(object): dump_switch = None call_num = 0 @@ -52,7 +66,6 @@ class DumpUtil(object): raise Exception("Model pretest: exit after iteration {}".format(DumpUtil.call_num - 1)) else: set_dump_switch("OFF") - DumpUtil.call_num += 1 class DumpConst: diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py b/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py index 85f555ed75847d7ce67e53a9e2183084476088e0..a297e7235f7b7a196112d8fa857513c4d5027f03 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py @@ -10,7 +10,9 @@ def iter_tracer(func): DumpUtil.dump_switch = "OFF" result = func(*args, **kwargs) DumpUtil.incr_iter_num_maybe_exit() + DumpUtil.call_num += 1 return result return func_wrapper -_BaseDataLoaderIter.__next__ = iter_tracer(torch.utils.data.dataloader._BaseDataLoaderIter.__next__) \ No newline at end of file +if msCheckerConfig.enable_dataloader: + _BaseDataLoaderIter.__next__ = iter_tracer(torch.utils.data.dataloader._BaseDataLoaderIter.__next__) \ No newline at end of file diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py b/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py index 69de65912fbc2a0d0ca3ceabe4d15a4f7f65e0c5..70d49251f060f88477695d54f2efef776d776334 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py @@ -17,7 +17,7 @@ lock = threading.Lock() def write_api_info_json(api_info): from api_accuracy_checker.dump.dump import DumpUtil dump_path = msCheckerConfig.dump_path - dump_path = os.path.join(msCheckerConfig.dump_path, "step" + str(DumpUtil.call_num - 1)) + dump_path = os.path.join(msCheckerConfig.dump_path, "step" + str((DumpUtil.call_num - 1) if msCheckerConfig.enable_dataloader else DumpUtil.call_num)) check_path_before_create(dump_path) create_directory(dump_path) rank = api_info.rank diff --git a/debug/accuracy_tools/api_accuracy_checker/test/ut/dump/test_info_dump.py b/debug/accuracy_tools/api_accuracy_checker/test/ut/dump/test_info_dump.py index 2574716c408f8c2c5f3433ff7abde698a9a219b5..0dd0898c167704a72a430f4779b13ce87e5dc9f4 100644 --- a/debug/accuracy_tools/api_accuracy_checker/test/ut/dump/test_info_dump.py +++ b/debug/accuracy_tools/api_accuracy_checker/test/ut/dump/test_info_dump.py @@ -13,7 +13,7 @@ class TestInfoDump(unittest.TestCase): with patch('api_accuracy_checker.dump.info_dump.write_json') as mock_write_json: write_api_info_json(api_info) rank = os.getpid() - mock_write_json.assert_called_with(f'./step1/backward_info_{rank}.json', api_info.grad_info_struct) + mock_write_json.assert_called_with(f'./step2/backward_info_{rank}.json', api_info.grad_info_struct) def test_write_api_info_json_invalid_type(self): api_info = APIInfo("test_api", True, True, "save_path") @@ -22,7 +22,7 @@ class TestInfoDump(unittest.TestCase): def tearDown(self): rank = os.getpid() - files = [f'./step1/backward_info_{rank}.json'] + files = [f'./step2/backward_info_{rank}.json'] for file in files: if os.path.exists(file): os.remove(file)