diff --git a/torch_npu/profiler/analysis/_npu_profiler.py b/torch_npu/profiler/analysis/_npu_profiler.py index cc23309bb0f90330bb1747df5159824e5ab55b91..2cee702bc512bcd41cc50845c0f77c5576fb0d9e 100644 --- a/torch_npu/profiler/analysis/_npu_profiler.py +++ b/torch_npu/profiler/analysis/_npu_profiler.py @@ -1,9 +1,8 @@ import multiprocessing - from .prof_common_func._constant import Constant, print_error_msg -from .prof_common_func._multi_process_pool import MultiProcessPool from .prof_common_func._path_manager import ProfilerPathManager +from .prof_common_func._multi_process_pool import MultiProcessPool from ._profiling_parser import ProfilingParser from ...utils._path_manager import PathManager @@ -34,11 +33,21 @@ class NpuProfiler: len(profiler_path_list)) MultiProcessPool().init_pool(process_number) + futures = [] for profiler_path in profiler_path_list: PathManager.check_directory_path_writeable(profiler_path) profiling_parser = ProfilingParser(profiler_path, analysis_type, output_path, kwargs) - MultiProcessPool().submit_task(profiling_parser.analyse_profiling_data) + future = MultiProcessPool().submit_task(profiling_parser.analyse_profiling_data) + futures.append(future) if not kwargs.get("async_mode", False): + timeout = 30 + if timeout: + import concurrent.futures + done, not_done = concurrent.futures.wait(futures, timeout=timeout) + if not_done: + print(f"Profiling analysis timed out after {timeout} seconds") + MultiProcessPool().close_pool(wait=False) + return MultiProcessPool().close_pool() @classmethod diff --git a/torch_npu/profiler/analysis/prof_common_func/_multi_process_pool.py b/torch_npu/profiler/analysis/prof_common_func/_multi_process_pool.py index 3f1031a9b3cf0626b31ace2ef41e0f06f221a849..5a3d58e91d821f88d5de8f2c604c9385185f8c41 100644 --- a/torch_npu/profiler/analysis/prof_common_func/_multi_process_pool.py +++ b/torch_npu/profiler/analysis/prof_common_func/_multi_process_pool.py @@ -1,4 +1,6 @@ import atexit +import os +import signal from 
class MultiProcessPool:
    """Lazily created process pool used to run tasks in worker processes.

    The pool wraps a single ``concurrent.futures.ProcessPoolExecutor``.
    ``close_pool`` shuts that executor down and then force-stops every worker
    process that is still running, so a hung task cannot outlive the pool.
    """

    # Seconds a worker gets to exit after terminate() before kill() is used.
    _TERMINATE_GRACE_SECONDS = 5

    def __init__(self):
        """Start without an executor and register ``close_pool`` to run at exit."""
        self._pool = None
        atexit.register(self.close_pool)

    def init_pool(self, max_workers=4):
        """Create the executor on first use and return it.

        Args:
            max_workers: Worker count used when the executor is created. It is
                ignored when an executor already exists.

        Returns:
            The active ``ProcessPoolExecutor``.
        """
        # The multiprocessing context is deliberately left untouched:
        # overriding ``Process`` on it would affect every user of that context
        # object, not only this pool.
        if not self._pool:
            self._pool = futures.ProcessPoolExecutor(max_workers=max_workers)
        return self._pool

    def close_pool(self, wait: bool = True):
        """Shut the executor down and stop any worker that is still alive.

        Calling this when no executor exists is a no-op, so it is safe to call
        more than once (it is also registered with ``atexit``).

        Args:
            wait: Forwarded to ``ProcessPoolExecutor.shutdown``. With
                ``wait=False`` the shutdown call is made without waiting, and
                any worker still alive afterwards is terminated.
        """
        pool, self._pool = self._pool, None
        if pool is None:
            return
        # Take the snapshot first: the executor may drop its worker table
        # while shutting down.
        workers = self._snapshot_workers(pool)
        pool.shutdown(wait=wait)
        self._stop_workers(workers)

    @staticmethod
    def _snapshot_workers(pool):
        """Return the executor's current worker processes as a list."""
        # NOTE(review): ``_processes`` (pid -> Process) is a private CPython
        # attribute of ProcessPoolExecutor; the defensive lookup keeps this
        # working, as a no-op, should it ever be missing or empty.
        table = getattr(pool, "_processes", None)
        return list(table.values()) if table else []

    @classmethod
    def _stop_workers(cls, workers):
        """Terminate, then kill, every worker that is still alive (best effort)."""
        stragglers = []
        # Signal every live worker first so the grace periods overlap instead
        # of adding up worker by worker.
        for worker in workers:
            try:
                if worker.is_alive():
                    worker.terminate()
                    stragglers.append(worker)
            except (OSError, ValueError):
                # The process is already gone or its handle was closed.
                continue
        for worker in stragglers:
            try:
                worker.join(timeout=cls._TERMINATE_GRACE_SECONDS)
                if worker.is_alive():
                    # terminate() was ignored; escalate to an unconditional kill.
                    worker.kill()
            except (OSError, ValueError):
                continue

    # NOTE(review): ``submit_task(self, func, *args, **kwargs)`` follows in the
    # original file. Only its first two lines are visible in this excerpt, so
    # it is not reproduced here.