From 8f8a40f3606b55d1e311fc98bff8a55cca1433b0 Mon Sep 17 00:00:00 2001 From: l30044004 Date: Wed, 6 Mar 2024 16:31:09 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E5=A4=9A=E6=AC=A1=E8=B0=83?= =?UTF-8?q?=E7=94=A8start=E5=92=8Cstop=20dump=E6=8A=A5=E9=94=99?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../python/ptdbg_ascend/debugger/precision_debugger.py | 8 +++++--- .../ptdbg_ascend/src/python/ptdbg_ascend/dump/dump.py | 5 ++--- .../ptdbg_ascend/src/python/ptdbg_ascend/dump/utils.py | 1 + 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/debugger/precision_debugger.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/debugger/precision_debugger.py index f71a9d4c54..2d1613c47b 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/debugger/precision_debugger.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/debugger/precision_debugger.py @@ -1,8 +1,9 @@ import os +from concurrent.futures import ThreadPoolExecutor import torch from ..common.utils import Const, check_switch_valid, generate_compare_script, check_is_npu, print_error_log, \ CompareException, print_warn_log -from ..dump.dump import DumpUtil, acc_cmp_dump, write_to_disk, get_pkl_file_path, reset_module_count, GLOBAL_THREAD_POOL +from ..dump.dump import DumpUtil, acc_cmp_dump, write_to_disk, get_pkl_file_path, reset_module_count from ..dump.utils import set_dump_path, set_dump_switch_print_info, generate_dump_path_str, \ set_dump_switch_config, set_backward_input from ..overflow_check.utils import OverFlowUtil @@ -108,6 +109,7 @@ class PrecisionDebugger: register_hook_core(instance.hook_func, instance.model) instance.first_start = False DumpUtil.dump_switch = "ON" + DumpUtil.dump_thread_pool = ThreadPoolExecutor() OverFlowUtil.overflow_check_switch = "ON" dump_path_str = generate_dump_path_str() set_dump_switch_print_info("ON", DumpUtil.dump_switch_mode, dump_path_str) @@ -130,8 +132,8 @@ class PrecisionDebugger: dump_path_str = generate_dump_path_str() set_dump_switch_print_info("OFF", DumpUtil.dump_switch_mode, dump_path_str) write_to_disk() - if DumpUtil.is_single_rank: - GLOBAL_THREAD_POOL.shutdown(wait=True) + if DumpUtil.is_single_rank and DumpUtil.dump_thread_pool: + DumpUtil.dump_thread_pool.shutdown(wait=True) if check_is_npu() and DumpUtil.dump_switch_mode in [Const.ALL, Const.API_STACK, Const.LIST, Const.RANGE, Const.API_LIST]: generate_compare_script(DumpUtil.dump_data_dir, get_pkl_file_path(), DumpUtil.dump_switch_mode) diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/dump.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/dump.py index a6b769ff2a..ea61be8636 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/dump.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/dump.py @@ -47,7 +47,6 @@ pkl_name = "" rank = os.getpid() + 100000 multi_output_apis = ["_sort_", "npu_flash_attention"] module_count = {} -GLOBAL_THREAD_POOL = ThreadPoolExecutor() class APIList(list): @@ -186,12 +185,12 @@ def dump_data(prefix, data_info): def thread_dump_data(prefix, data_info): - GLOBAL_THREAD_POOL.submit(dump_data, prefix, data_info) + DumpUtil.dump_thread_pool.submit(dump_data, prefix, data_info) def dump_data_by_rank_count(dump_step, prefix, data_info): print_info_log(f"ptdbg is analyzing rank{rank} api: {prefix}" + " " * 10, end='\r') - if DumpUtil.is_single_rank: + if DumpUtil.is_single_rank and DumpUtil.dump_thread_pool: thread_dump_data(prefix, data_info) else: dump_data(prefix, data_info) diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/utils.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/utils.py index bb0ae82c41..bea18501ae 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/utils.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/utils.py @@ -87,6 +87,7 @@ class DumpUtil(object): need_replicate = False summary_mode = "all" is_single_rank = None + dump_thread_pool = None @staticmethod -- Gitee