From ba293a1410f5854a4393b48d7eec0aa10b0928dd Mon Sep 17 00:00:00 2001 From: Linwei-Ying Date: Thu, 4 Sep 2025 10:20:02 +0800 Subject: [PATCH] compare safety bugfix --- .../msprobe/core/common/utils.py | 13 +++++- .../msprobe/core/compare/highlight.py | 19 ++++++--- .../core/compare/multiprocessing_compute.py | 27 +++++++----- .../msprobe/core/compare/utils.py | 19 ++++++--- .../mindspore/compare/common_dir_compare.py | 42 +++++++++---------- .../msprobe/pytorch/compare/utils.py | 3 +- 6 files changed, 77 insertions(+), 46 deletions(-) diff --git a/debug/accuracy_tools/msprobe/core/common/utils.py b/debug/accuracy_tools/msprobe/core/common/utils.py index 83455b8e3..31c6935ca 100644 --- a/debug/accuracy_tools/msprobe/core/common/utils.py +++ b/debug/accuracy_tools/msprobe/core/common/utils.py @@ -83,7 +83,8 @@ class MsprobeBaseException(Exception): INVALID_API_NAME_ERROR = 36 CROSS_FRAME_ERROR = 37 MISSING_THRESHOLD_ERROR = 38 - WRONG_THRESHOLD_ERROR = 38 + WRONG_THRESHOLD_ERROR = 39 + MULTIPROCESS_ERROR = 40 def __init__(self, code, error_info: str = ""): super(MsprobeBaseException, self).__init__() @@ -348,6 +349,16 @@ def get_stack_construct_by_dump_json_path(dump_json_path): stack_json = os.path.join(directory, "stack.json") construct_json = os.path.join(directory, "construct.json") + stack_json_exist = os.path.exists(stack_json) + construct_json_exist = os.path.exists(construct_json) + + if not stack_json_exist and not construct_json_exist: + logger.info("stack.json and construct.json not found") + return {}, {} + if not stack_json_exist or not construct_json_exist: + logger.error("stack.json or construct.json not found, please check.") + raise CompareException(CompareException.INVALID_PATH_ERROR) + stack = load_json(stack_json) construct = load_json(construct_json) return stack, construct diff --git a/debug/accuracy_tools/msprobe/core/compare/highlight.py b/debug/accuracy_tools/msprobe/core/compare/highlight.py index 6e5aaa232..33c7e15f9 100644 --- a/debug/accuracy_tools/msprobe/core/compare/highlight.py +++ b/debug/accuracy_tools/msprobe/core/compare/highlight.py @@ -26,7 +26,7 @@ from tqdm import tqdm from msprobe.core.common.const import CompareConst, Const from msprobe.core.common.file_utils import save_workbook from msprobe.core.common.log import logger -from msprobe.core.common.utils import get_header_index +from msprobe.core.common.utils import get_header_index, CompareException from msprobe.core.compare.utils import table_value_is_valid, gen_api_batches from msprobe.core.compare.config import ModeConfig @@ -359,18 +359,25 @@ class HighLight: def err_call(args): logger.error("Multiprocessing malicious value check failed! Reason: {}".format(args)) - try: - pool.close() - except OSError: - logger.error("Pool terminate failed") result_df_columns = result_df.columns.tolist() for column in result_df_columns: self.value_check(column) + async_results = [] for df_chunk in chunks: - pool.apply_async(func, args=(df_chunk, result_df_columns,), error_callback=err_call) + result = pool.apply_async(func, args=(df_chunk, result_df_columns,), error_callback=err_call) + async_results.append(result) pool.close() + + for ar in async_results: + try: + ar.get(timeout=3600) + except Exception as e: + logger.error(f"Task failed with exception: {e}") + pool.terminate() + raise CompareException(CompareException.MULTIPROCESS_ERROR) from e + pool.join() def df_malicious_value_check(self, result_df): diff --git a/debug/accuracy_tools/msprobe/core/compare/multiprocessing_compute.py b/debug/accuracy_tools/msprobe/core/compare/multiprocessing_compute.py index 12b253e16..28cf28fde 100644 --- a/debug/accuracy_tools/msprobe/core/compare/multiprocessing_compute.py +++ b/debug/accuracy_tools/msprobe/core/compare/multiprocessing_compute.py @@ -52,16 +52,20 @@ def _ms_graph_handle_multi_process(func, result_df, mode): def err_call(args): logger.error('multiprocess compare failed! Reason: {}'.format(args)) - try: - pool.close() - except OSError as e: - logger.error(f'pool terminate failed: {str(e)}') for df_chunk in df_chunks: result = pool.apply_async(func, args=(df_chunk, mode), error_callback=err_call) results.append(result) - final_results = [r.get() for r in results] + pool.close() + + try: + final_results = [r.get(timeout=3600) for r in results] + except Exception as e: + logger.error(f"Task failed with exception: {e}") + pool.terminate() + raise CompareException(CompareException.MULTIPROCESS_ERROR) from e + pool.join() return pd.concat(final_results, ignore_index=True) @@ -277,10 +281,6 @@ class CompareRealData: def err_call(args): logger.error('multiprocess compare failed! Reason: {}'.format(args)) - try: - pool.close() - except OSError: - logger.error("pool terminate failed") progress_bar = tqdm(total=len(result_df), desc="API/Module Item Compare Process", unit="row", ncols=100) @@ -298,7 +298,14 @@ class CompareRealData: ) results.append(result) - final_results = [r.get() for r in results] pool.close() + + try: + final_results = [r.get(timeout=3600) for r in results] + except Exception as e: + logger.error(f"Task failed with exception: {e}") + pool.terminate() + raise CompareException(CompareException.MULTIPROCESS_ERROR) from e + pool.join() return pd.concat(final_results, ignore_index=True) diff --git a/debug/accuracy_tools/msprobe/core/compare/utils.py b/debug/accuracy_tools/msprobe/core/compare/utils.py index 5503bdd3f..2493442d6 100644 --- a/debug/accuracy_tools/msprobe/core/compare/utils.py +++ b/debug/accuracy_tools/msprobe/core/compare/utils.py @@ -695,10 +695,6 @@ def get_sorted_ranks(npu_dump_dir, bench_dump_dir): def multi_statistics_compare(func, func_args): def err_call(args): logger.error(f'Multiprocess statistics compare failed! Reason: {args}') - try: - pool.close() - except OSError: - logger.error("Pool terminate failed") compare_func, input_param_nr_list, output_path, kwargs = func_args @@ -715,9 +711,22 @@ def multi_statistics_compare(func, func_args): chunks[i].append(input_param_nr_list[param_num - remainder + i]) pool = multiprocessing.Pool(process_num) + + async_results = [] for chunk in chunks: - pool.apply_async(func, args=(compare_func, chunk, output_path, kwargs), error_callback=err_call) + result = pool.apply_async(func, args=(compare_func, chunk, output_path, kwargs), error_callback=err_call) + async_results.append(result) + pool.close() + + for ar in async_results: + try: + ar.get(timeout=3600) + except Exception as e: + logger.error(f"Task failed with exception: {e}") + pool.terminate() + raise CompareException(CompareException.MULTIPROCESS_ERROR) from e + pool.join() diff --git a/debug/accuracy_tools/msprobe/mindspore/compare/common_dir_compare.py b/debug/accuracy_tools/msprobe/mindspore/compare/common_dir_compare.py index a973f8e08..af7a78177 100644 --- a/debug/accuracy_tools/msprobe/mindspore/compare/common_dir_compare.py +++ b/debug/accuracy_tools/msprobe/mindspore/compare/common_dir_compare.py @@ -228,34 +228,30 @@ def do_multi_process(func, map_dict): def err_call(args): logger.error('multiprocess compare failed! Reason: {}'.format(args)) - try: - pool.close() - except OSError as e: - logger.error(f'pool terminate failed: {str(e)}') + results = [] + + # 提交任务到进程池 + for process_idx, (df_chunk, map_chunk) in enumerate(zip(df_chunks, map_chunks)): + start_idx = df_chunk_size * process_idx + result = pool.apply_async( + func, + args=(df_chunk, start_idx, map_chunk, lock), + error_callback=err_call, + callback=partial(update_progress, len(map_chunk), lock) + ) + results.append(result) + pool.close() + try: - # 提交任务到进程池 - for process_idx, (df_chunk, map_chunk) in enumerate(zip(df_chunks, map_chunks)): - start_idx = df_chunk_size * process_idx - result = pool.apply_async( - func, - args=(df_chunk, start_idx, map_chunk, lock), - error_callback=err_call, - callback=partial(update_progress, len(map_chunk), lock) - ) - results.append(result) - - final_results = [r.get() for r in results] - # 等待所有任务完成 - pool.close() - pool.join() - return pd.concat(final_results, ignore_index=True) + final_results = [r.get(timeout=3600) for r in results] except Exception as e: - logger.error(f"\nMain process error: {str(e)}") + logger.error(f"Task failed with exception: {e}") pool.terminate() return pd.DataFrame({}) - finally: - pool.close() + # 等待所有任务完成 + pool.join() + return pd.concat(final_results, ignore_index=True) def initialize_result_df(total_size): diff --git a/debug/accuracy_tools/msprobe/pytorch/compare/utils.py b/debug/accuracy_tools/msprobe/pytorch/compare/utils.py index 16473ff38..4d58f447d 100644 --- a/debug/accuracy_tools/msprobe/pytorch/compare/utils.py +++ b/debug/accuracy_tools/msprobe/pytorch/compare/utils.py @@ -35,7 +35,8 @@ def read_pt_data(dir_path, file_name): data_value = load_pt(data_path, to_cpu=True).detach() except RuntimeError as e: # 这里捕获 load_pt 中抛出的异常 - logger.error(f"Failed to load the .pt file at {data_path}.") + data_path_file_name = os.path.basename(data_path) + logger.error(f"Failed to load the .pt file at {data_path_file_name}.") raise CompareException(CompareException.INVALID_FILE_ERROR) from e except AttributeError as e: # 这里捕获 detach 方法抛出的异常 -- Gitee