diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/debugger/precision_debugger.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/debugger/precision_debugger.py index f71a9d4c545fed7138c88eb87e91bf19d83609b1..2d1613c47b12063b32dd9c46edaaf794b17c2053 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/debugger/precision_debugger.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/debugger/precision_debugger.py @@ -1,8 +1,9 @@ import os +from concurrent.futures import ThreadPoolExecutor import torch from ..common.utils import Const, check_switch_valid, generate_compare_script, check_is_npu, print_error_log, \ CompareException, print_warn_log -from ..dump.dump import DumpUtil, acc_cmp_dump, write_to_disk, get_pkl_file_path, reset_module_count, GLOBAL_THREAD_POOL +from ..dump.dump import DumpUtil, acc_cmp_dump, write_to_disk, get_pkl_file_path, reset_module_count from ..dump.utils import set_dump_path, set_dump_switch_print_info, generate_dump_path_str, \ set_dump_switch_config, set_backward_input from ..overflow_check.utils import OverFlowUtil @@ -108,6 +109,7 @@ class PrecisionDebugger: register_hook_core(instance.hook_func, instance.model) instance.first_start = False DumpUtil.dump_switch = "ON" + DumpUtil.dump_thread_pool = ThreadPoolExecutor() OverFlowUtil.overflow_check_switch = "ON" dump_path_str = generate_dump_path_str() set_dump_switch_print_info("ON", DumpUtil.dump_switch_mode, dump_path_str) @@ -130,8 +132,8 @@ class PrecisionDebugger: dump_path_str = generate_dump_path_str() set_dump_switch_print_info("OFF", DumpUtil.dump_switch_mode, dump_path_str) write_to_disk() - if DumpUtil.is_single_rank: - GLOBAL_THREAD_POOL.shutdown(wait=True) + if DumpUtil.is_single_rank and DumpUtil.dump_thread_pool: + DumpUtil.dump_thread_pool.shutdown(wait=True) if check_is_npu() and DumpUtil.dump_switch_mode in [Const.ALL, Const.API_STACK, Const.LIST, Const.RANGE, Const.API_LIST]: generate_compare_script(DumpUtil.dump_data_dir, get_pkl_file_path(), DumpUtil.dump_switch_mode) diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/dump.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/dump.py index a6b769ff2a41955aa6363f08d086a091335bf5f1..ea61be86364504885a3db240e60ddd60d360564b 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/dump.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/dump.py @@ -47,7 +47,6 @@ pkl_name = "" rank = os.getpid() + 100000 multi_output_apis = ["_sort_", "npu_flash_attention"] module_count = {} -GLOBAL_THREAD_POOL = ThreadPoolExecutor() class APIList(list): @@ -186,12 +185,12 @@ def dump_data(prefix, data_info): def thread_dump_data(prefix, data_info): - GLOBAL_THREAD_POOL.submit(dump_data, prefix, data_info) + DumpUtil.dump_thread_pool.submit(dump_data, prefix, data_info) def dump_data_by_rank_count(dump_step, prefix, data_info): print_info_log(f"ptdbg is analyzing rank{rank} api: {prefix}" + " " * 10, end='\r') - if DumpUtil.is_single_rank: + if DumpUtil.is_single_rank and DumpUtil.dump_thread_pool: thread_dump_data(prefix, data_info) else: dump_data(prefix, data_info) diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/utils.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/utils.py index bb0ae82c4110d568e31cd21a261f5875548c8672..bea18501aefaba4916cba7ad3b7e455d0c29033c 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/utils.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/utils.py @@ -87,6 +87,7 @@ class DumpUtil(object): need_replicate = False summary_mode = "all" is_single_rank = None + dump_thread_pool = None @staticmethod