diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/dump.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/dump.py index e726d2bc277c6b69a3704d794ad1a748f2655387..e2d1830f65334b5a9162e3003c09ff38c9d5e0c9 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/dump.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/dump.py @@ -40,8 +40,6 @@ from ..common.file_check_util import FileOpen, change_mode, FileCheckConst, chec forward_init_status = False backward_init_status = False -backward_threading_id = 0 - api_list = [] thread_lock = threading.Lock() pkl_name = "" @@ -100,14 +98,6 @@ def get_tensor_data_info(data, *tensor_args): return DataInfo([], summary_data, str(data.dtype), tuple(data.shape)) -def json_dump_condition(prefix): - cur_threading_id = threading.current_thread().ident - global backward_threading_id - if not backward_threading_id and Const.BACKWARD in prefix: - backward_threading_id = cur_threading_id - return (Const.BACKWARD in prefix and backward_threading_id == cur_threading_id) or 'forward' in prefix - - def dump_tensor(x, prefix, dump_step): if isinstance(x, (tuple, list)) and x: for i, item in enumerate(x): @@ -137,15 +127,14 @@ def dump_data(dump_step, prefix, data_info): global api_list thread_lock.acquire() try: - if json_dump_condition(prefix): - output_path = os.path.join(DumpUtil.dump_data_dir, f'{prefix}.npy') - check_path_length(output_path) - check_path_pattern_vaild(output_path) - if not DumpUtil.summary_only: - np.save(output_path, data_info.save_data) - change_mode(output_path, FileCheckConst.DATA_FILE_AUTHORITY) - api_list.append([prefix, dump_step, [], data_info.dtype, data_info.shape, data_info.summary_data]) - print_info_log(f"ptdbg is analyzing rank{rank} api: {prefix}" + " " * 10, end='\r') + output_path = os.path.join(DumpUtil.dump_data_dir, f'{prefix}.npy') + check_path_length(output_path) + check_path_pattern_vaild(output_path) + if not DumpUtil.summary_only: + np.save(output_path, data_info.save_data) + change_mode(output_path, FileCheckConst.DATA_FILE_AUTHORITY) + api_list.append([prefix, dump_step, [], data_info.dtype, data_info.shape, data_info.summary_data]) + print_info_log(f"ptdbg is analyzing rank{rank} api: {prefix}" + " " * 10, end='\r') except Exception as e: print_warn_log("Dump data failed, error: {}".format(e)) finally: @@ -170,10 +159,9 @@ def dump_stack_info(name_template): prefix = name_template.format("stack_info") if DumpUtil.dump_switch_mode in Const.DUMP_MODE: - if json_dump_condition(prefix): - complement_set = set(['forward', 'backward', 'input', 'output']) - set(DumpUtil.dump_mode) - if not any(mode in prefix for mode in complement_set): - api_list.append([prefix, stack_str]) + complement_set = set(['forward', 'backward', 'input', 'output']) - set(DumpUtil.dump_mode) + if not any(mode in prefix for mode in complement_set): + api_list.append([prefix, stack_str]) else: api_list.append([prefix, stack_str]) diff --git a/debug/accuracy_tools/ptdbg_ascend/test/ut/test_dump.py b/debug/accuracy_tools/ptdbg_ascend/test/ut/test_dump.py index 92c52bdf286d788a3dad44ed651ee6a9a75203fc..9673c292ba20cef94b926986ddac270f748645e9 100644 --- a/debug/accuracy_tools/ptdbg_ascend/test/ut/test_dump.py +++ b/debug/accuracy_tools/ptdbg_ascend/test/ut/test_dump.py @@ -43,10 +43,6 @@ class TestDump(unittest.TestCase): self.assertEqual(data_info.dtype, 'torch.float32') self.assertEqual(data_info.shape, (3,)) - def test_json_dump_condition(self): - result = json_dump_condition(self.prefix) - self.assertEqual(result, False) - def test_get_pkl_file_path(self): result = get_pkl_file_path() self.assertEqual(result, "")