diff --git a/debug/accuracy_tools/msprobe/core/common/utils.py b/debug/accuracy_tools/msprobe/core/common/utils.py index 5645cba6919afb36c56805b6dcce614bf7969ff2..97b2683dac88df070d1fbcd8596662ac5a8d458f 100644 --- a/debug/accuracy_tools/msprobe/core/common/utils.py +++ b/debug/accuracy_tools/msprobe/core/common/utils.py @@ -82,6 +82,9 @@ class MsprobeBaseException(Exception): INVALID_STATE_ERROR = 35 INVALID_API_NAME_ERROR = 36 CROSS_FRAME_ERROR = 37 + MISSING_THRESHOLD_ERROR = 38 + WRONG_THRESHOLD_ERROR = 39 + MULTIPROCESS_ERROR = 40 def __init__(self, code, error_info: str = ""): super(MsprobeBaseException, self).__init__() @@ -351,6 +354,16 @@ def get_stack_construct_by_dump_json_path(dump_json_path): stack_json = os.path.join(directory, "stack.json") construct_json = os.path.join(directory, "construct.json") + stack_json_exist = os.path.exists(stack_json) + construct_json_exist = os.path.exists(construct_json) + + if not stack_json_exist and not construct_json_exist: + logger.info("stack.json and construct.json not found") + return {}, {} + if not stack_json_exist or not construct_json_exist: + logger.error("stack.json or construct.json not found, please check.") + raise CompareException(CompareException.INVALID_PATH_ERROR) + stack = load_json(stack_json) construct = load_json(construct_json) return stack, construct @@ -700,4 +713,3 @@ def check_process_num(process_num): raise ValueError(f"process_num({process_num}) is not a positive integer") if process_num > Const.MAX_PROCESS_NUM: raise ValueError(f"The maximum supported process_num is {Const.MAX_PROCESS_NUM}, current value: {process_num}.") - diff --git a/debug/accuracy_tools/msprobe/core/compare/highlight.py b/debug/accuracy_tools/msprobe/core/compare/highlight.py index 71959d77d1ad3f3e293b103c6844d9641c9e51be..03218fac320cec8cc93062f50b3fb7b4e587fd85 100644 --- a/debug/accuracy_tools/msprobe/core/compare/highlight.py +++ b/debug/accuracy_tools/msprobe/core/compare/highlight.py @@ -396,18 +396,25 @@ class HighLight: def err_call(args): logger.error("Multiprocessing malicious value check failed! Reason: {}".format(args)) - try: - pool.close() - except OSError: - logger.error("Pool terminate failed") result_df_columns = result_df.columns.tolist() for column in result_df_columns: self.value_check(column) + async_results = [] for df_chunk in chunks: - pool.apply_async(func, args=(df_chunk, result_df_columns,), error_callback=err_call) + result = pool.apply_async(func, args=(df_chunk, result_df_columns,), error_callback=err_call) + async_results.append(result) pool.close() + + for ar in async_results: + try: + ar.get(timeout=3600) + except Exception as e: + logger.error(f"Task failed with exception: {e}") + pool.terminate() + raise CompareException(CompareException.MULTIPROCESS_ERROR) from e + pool.join() def df_malicious_value_check(self, df_chunk, result_df_columns): diff --git a/debug/accuracy_tools/msprobe/core/compare/multiprocessing_compute.py b/debug/accuracy_tools/msprobe/core/compare/multiprocessing_compute.py index 8e23f9f8b9f0f1cb33e58013aaa325da5ffdf11b..d4e7d1975663ce45d4750740a989f459863fe96f 100644 --- a/debug/accuracy_tools/msprobe/core/compare/multiprocessing_compute.py +++ b/debug/accuracy_tools/msprobe/core/compare/multiprocessing_compute.py @@ -52,16 +52,20 @@ def _ms_graph_handle_multi_process(func, result_df, mode): def err_call(args): logger.error('multiprocess compare failed! Reason: {}'.format(args)) - try: - pool.close() - except OSError as e: - logger.error(f'pool terminate failed: {str(e)}') for df_chunk in df_chunks: result = pool.apply_async(func, args=(df_chunk, mode), error_callback=err_call) results.append(result) - final_results = [r.get() for r in results] + pool.close() + + try: + final_results = [r.get(timeout=3600) for r in results] + except Exception as e: + logger.error(f"Task failed with exception: {e}") + pool.terminate() + raise CompareException(CompareException.MULTIPROCESS_ERROR) from e + pool.join() return pd.concat(final_results, ignore_index=True) @@ -267,10 +271,6 @@ class CompareRealData: def err_call(args): logger.error('multiprocess compare failed! Reason: {}'.format(args)) - try: - pool.close() - except OSError: - logger.error("pool terminate failed") progress_bar = tqdm(total=len(result_df), desc="API/Module Item Compare Process", unit="row", ncols=100) @@ -288,7 +288,14 @@ class CompareRealData: ) results.append(result) - final_results = [r.get() for r in results] pool.close() + + try: + final_results = [r.get(timeout=3600) for r in results] + except Exception as e: + logger.error(f"Task failed with exception: {e}") + pool.terminate() + raise CompareException(CompareException.MULTIPROCESS_ERROR) from e + pool.join() return pd.concat(final_results, ignore_index=True) diff --git a/debug/accuracy_tools/msprobe/mindspore/compare/common_dir_compare.py b/debug/accuracy_tools/msprobe/mindspore/compare/common_dir_compare.py index 7371b12200d4628bf508c5abbcc1f997ddf9846e..3ee32895072f77b1ec03627c5c1b36ee5352350d 100644 --- a/debug/accuracy_tools/msprobe/mindspore/compare/common_dir_compare.py +++ b/debug/accuracy_tools/msprobe/mindspore/compare/common_dir_compare.py @@ -216,34 +216,30 @@ def do_multi_process(func, map_dict): def err_call(args): logger.error('multiprocess compare failed! Reason: {}'.format(args)) - try: - pool.close() - except OSError as e: - logger.error(f'pool terminate failed: {str(e)}') + results = [] + + # 提交任务到进程池 + for process_idx, (df_chunk, map_chunk) in enumerate(zip(df_chunks, map_chunks)): + start_idx = df_chunk_size * process_idx + result = pool.apply_async( + func, + args=(df_chunk, start_idx, map_chunk, lock), + error_callback=err_call, + callback=partial(update_progress, len(map_chunk), lock) + ) + results.append(result) + pool.close() + try: - # 提交任务到进程池 - for process_idx, (df_chunk, map_chunk) in enumerate(zip(df_chunks, map_chunks)): - start_idx = df_chunk_size * process_idx - result = pool.apply_async( - func, - args=(df_chunk, start_idx, map_chunk, lock), - error_callback=err_call, - callback=partial(update_progress, len(map_chunk), lock) - ) - results.append(result) - - final_results = [r.get() for r in results] - # 等待所有任务完成 - pool.close() - pool.join() - return pd.concat(final_results, ignore_index=True) + final_results = [r.get(timeout=3600) for r in results] except Exception as e: - logger.error(f"\nMain process error: {str(e)}") + logger.error(f"Task failed with exception: {e}") pool.terminate() return pd.DataFrame({}) - finally: - pool.close() + # 等待所有任务完成 + pool.join() + return pd.concat(final_results, ignore_index=True) def initialize_result_df(total_size): diff --git a/debug/accuracy_tools/msprobe/pytorch/compare/utils.py b/debug/accuracy_tools/msprobe/pytorch/compare/utils.py index 16473ff386d89de5f3bbb269e69837c07a950ea5..4d58f447d1b30194a74e7aebb38eed8e876a76c7 100644 --- a/debug/accuracy_tools/msprobe/pytorch/compare/utils.py +++ b/debug/accuracy_tools/msprobe/pytorch/compare/utils.py @@ -35,7 +35,8 @@ def read_pt_data(dir_path, file_name): data_value = load_pt(data_path, to_cpu=True).detach() except RuntimeError as e: # 这里捕获 load_pt 中抛出的异常 - logger.error(f"Failed to load the .pt file at {data_path}.") + data_path_file_name = os.path.basename(data_path) + logger.error(f"Failed to load the .pt file at {data_path_file_name}.") raise CompareException(CompareException.INVALID_FILE_ERROR) from e except AttributeError as e: # 这里捕获 detach 方法抛出的异常