diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/compare/acc_compare.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/compare/acc_compare.py
index bb39809891611ea0b3d17e660f5841849195b1ff..0692b8aad16a972db5208c54b1a90b241116aea7 100644
--- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/compare/acc_compare.py
+++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/compare/acc_compare.py
@@ -409,32 +409,48 @@ def read_dump_path(result_path):
 
 
 def _handle_multi_process(func, input_parma, result_path, lock):
-    process_num = int((multiprocessing.cpu_count() + 1) / 2)
+    try:
+        # "ulimit" is a shell builtin, so it has to be run through a shell
+        ulimit_output = subprocess.check_output("ulimit -u", shell=True)
+        max_user_processes = int(ulimit_output.strip())
+    except (subprocess.CalledProcessError, ValueError) as e:
+        print_warn_log(f"Failed to get ulimit: {e}")
+        max_user_processes = 1024
+    cpu_count = multiprocessing.cpu_count() // 2
+    estimated_max_processes = max_user_processes // 4
+    process_num = max(1, min(cpu_count, estimated_max_processes))
+    if process_num < cpu_count:
+        print_warn_log(f"Reducing number of processes to {process_num} due to system limits.")
     op_name_mapping_dict = read_dump_path(result_path)
-    op_names = []
-    for _ in range(process_num):
-        op_names.append([])
+    op_names = [[] for _ in range(process_num)]
     all_op_names = list(op_name_mapping_dict.keys())
     for i, op_name in enumerate(all_op_names):
         op_names[i % process_num].append(op_name)
     all_tasks = []
-    pool = multiprocessing.Pool(process_num)
+    try:
+        pool = multiprocessing.Pool(process_num)
+    except RuntimeError as e:
+        print_error_log(f"Failed to start process pool: {e}")
+        return
 
     def err_call(args):
         print_error_log('multiprocess compare failed! Reason: {}'.format(args))
         try:
             pool.terminate()
-            if os.path.exists(result_path):
-                os.remove(result_path)
         except OSError as e:
             print_error_log("pool terminate failed")
-
-    for process_idx, fusion_op_names in enumerate(op_names):
-        idx = [process_num, process_idx]
-        task = pool.apply_async(func,
-                                args=(idx, fusion_op_names, op_name_mapping_dict, result_path, lock, input_parma),
-                                error_callback=err_call)
-        all_tasks.append(task)
+    try:
+        for process_idx, fusion_op_names in enumerate(op_names):
+            idx = [process_num, process_idx]
+            task = pool.apply_async(func,
+                                    args=(idx, fusion_op_names, op_name_mapping_dict, result_path, lock, input_parma),
+                                    error_callback=err_call)
+            all_tasks.append(task)
+    except RuntimeError as e:
+        print_error_log(f"Error starting new tasks: {e}")
+        pool.terminate()
+        if os.path.exists(result_path):
+            os.remove(result_path)
+        return
     pool.close()
     pool.join()
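
Note (not part of the patch): the new cap on the pool size is half the CPUs bounded by `ulimit -u` divided by 4. Below is a minimal sketch of the same idea using Python's `resource` module instead of shelling out; it assumes a Linux host, where `RLIMIT_NPROC` corresponds to `ulimit -u`. The helper name `estimate_process_num`, the divisor 4, and the 1024 fallback are illustrative choices mirroring the patch, not existing project APIs.

```python
# Sketch only: derive a conservative worker count without spawning a shell.
import multiprocessing
import resource


def estimate_process_num(default_limit=1024):
    """Hypothetical helper: bound the pool size by CPU count and the per-user process limit."""
    try:
        # Soft per-user process limit, i.e. what `ulimit -u` reports on Linux
        soft_limit, _ = resource.getrlimit(resource.RLIMIT_NPROC)
        if soft_limit == resource.RLIM_INFINITY:
            soft_limit = default_limit
    except (AttributeError, OSError, ValueError):
        # RLIMIT_NPROC unavailable or unreadable on this platform
        soft_limit = default_limit
    cpu_budget = multiprocessing.cpu_count() // 2    # use at most half the CPUs
    process_budget = soft_limit // 4                 # leave headroom for other processes
    return max(1, min(cpu_budget, process_budget))   # never return 0


if __name__ == "__main__":
    print(estimate_process_num())
```

Taking `max(1, ...)` last guarantees at least one worker even on a single-CPU machine or under an unusually low process limit.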