diff --git a/debug/accuracy_tools/msprobe/core/common/utils.py b/debug/accuracy_tools/msprobe/core/common/utils.py index 738455fd80ab9bf6b1698b656b4f531958bdd31d..5535709707e9ddb2317e7a81173b729a3526c6a4 100644 --- a/debug/accuracy_tools/msprobe/core/common/utils.py +++ b/debug/accuracy_tools/msprobe/core/common/utils.py @@ -82,7 +82,8 @@ class MsprobeBaseException(Exception): INVALID_API_NAME_ERROR = 36 CROSS_FRAME_ERROR = 37 MISSING_THRESHOLD_ERROR = 38 - WRONG_THRESHOLD_ERROR = 38 + WRONG_THRESHOLD_ERROR = 39 + MULTIPROCESS_ERROR = 40 def __init__(self, code, error_info: str = ""): super(MsprobeBaseException, self).__init__() @@ -271,6 +272,16 @@ def get_stack_construct_by_dump_json_path(dump_json_path): stack_json = os.path.join(directory, "stack.json") construct_json = os.path.join(directory, "construct.json") + stack_json_exist = os.path.exists(stack_json) + construct_json_exist = os.path.exists(construct_json) + + if not stack_json_exist and not construct_json_exist: + logger.info("stack.json and construct.json not found") + return {}, {} + if not stack_json_exist or not construct_json_exist: + logger.error("stack.json or construct.json not found, please check.") + raise CompareException(CompareException.INVALID_PATH_ERROR) + stack = load_json(stack_json) construct = load_json(construct_json) return stack, construct diff --git a/debug/accuracy_tools/msprobe/core/compare/highlight.py b/debug/accuracy_tools/msprobe/core/compare/highlight.py index 6e5aaa232df114b3554d61ac564a1a7072054feb..33c7e15f949cb7ae4cb3b79c958b010308d31629 100644 --- a/debug/accuracy_tools/msprobe/core/compare/highlight.py +++ b/debug/accuracy_tools/msprobe/core/compare/highlight.py @@ -26,7 +26,7 @@ from tqdm import tqdm from msprobe.core.common.const import CompareConst, Const from msprobe.core.common.file_utils import save_workbook from msprobe.core.common.log import logger -from msprobe.core.common.utils import get_header_index +from msprobe.core.common.utils import get_header_index, CompareException from msprobe.core.compare.utils import table_value_is_valid, gen_api_batches from msprobe.core.compare.config import ModeConfig @@ -359,18 +359,25 @@ class HighLight: def err_call(args): logger.error("Multiprocessing malicious value check failed! Reason: {}".format(args)) - try: - pool.close() - except OSError: - logger.error("Pool terminate failed") result_df_columns = result_df.columns.tolist() for column in result_df_columns: self.value_check(column) + async_results = [] for df_chunk in chunks: - pool.apply_async(func, args=(df_chunk, result_df_columns,), error_callback=err_call) + result = pool.apply_async(func, args=(df_chunk, result_df_columns,), error_callback=err_call) + async_results.append(result) pool.close() + + for ar in async_results: + try: + ar.get(timeout=3600) + except Exception as e: + logger.error(f"Task failed with exception: {e}") + pool.terminate() + raise CompareException(CompareException.MULTIPROCESS_ERROR) from e + pool.join() def df_malicious_value_check(self, result_df): diff --git a/debug/accuracy_tools/msprobe/core/compare/multiprocessing_compute.py b/debug/accuracy_tools/msprobe/core/compare/multiprocessing_compute.py index 12b253e16a0bd68e5c8ab3ae13ec832aa374dcae..28cf28fded3d96578cdf5b6c2773c40360a8ac89 100644 --- a/debug/accuracy_tools/msprobe/core/compare/multiprocessing_compute.py +++ b/debug/accuracy_tools/msprobe/core/compare/multiprocessing_compute.py @@ -52,16 +52,20 @@ def _ms_graph_handle_multi_process(func, result_df, mode): def err_call(args): logger.error('multiprocess compare failed! Reason: {}'.format(args)) - try: - pool.close() - except OSError as e: - logger.error(f'pool terminate failed: {str(e)}') for df_chunk in df_chunks: result = pool.apply_async(func, args=(df_chunk, mode), error_callback=err_call) results.append(result) - final_results = [r.get() for r in results] + pool.close() + + try: + final_results = [r.get(timeout=3600) for r in results] + except Exception as e: + logger.error(f"Task failed with exception: {e}") + pool.terminate() + raise CompareException(CompareException.MULTIPROCESS_ERROR) from e + pool.join() return pd.concat(final_results, ignore_index=True) @@ -277,10 +281,6 @@ class CompareRealData: def err_call(args): logger.error('multiprocess compare failed! Reason: {}'.format(args)) - try: - pool.close() - except OSError: - logger.error("pool terminate failed") progress_bar = tqdm(total=len(result_df), desc="API/Module Item Compare Process", unit="row", ncols=100) @@ -298,7 +298,14 @@ class CompareRealData: ) results.append(result) - final_results = [r.get() for r in results] pool.close() + + try: + final_results = [r.get(timeout=3600) for r in results] + except Exception as e: + logger.error(f"Task failed with exception: {e}") + pool.terminate() + raise CompareException(CompareException.MULTIPROCESS_ERROR) from e + pool.join() return pd.concat(final_results, ignore_index=True) diff --git a/debug/accuracy_tools/msprobe/core/compare/utils.py b/debug/accuracy_tools/msprobe/core/compare/utils.py index 2e9ffca318e265b4c5292999bd6be2bbe178e191..07821ca57ce16b33dfefafbc196472bf64cae42f 100644 --- a/debug/accuracy_tools/msprobe/core/compare/utils.py +++ b/debug/accuracy_tools/msprobe/core/compare/utils.py @@ -693,10 +693,6 @@ def get_sorted_ranks(npu_dump_dir, bench_dump_dir): def multi_statistics_compare(func, func_args): def err_call(args): logger.error(f'Multiprocess statistics compare failed! Reason: {args}') - try: - pool.close() - except OSError: - logger.error("Pool terminate failed") compare_func, input_param_nr_list, output_path, kwargs = func_args @@ -713,9 +709,22 @@ def multi_statistics_compare(func, func_args): chunks[i].append(input_param_nr_list[param_num - remainder + i]) pool = multiprocessing.Pool(process_num) + + async_results = [] for chunk in chunks: - pool.apply_async(func, args=(compare_func, chunk, output_path, kwargs), error_callback=err_call) + result = pool.apply_async(func, args=(compare_func, chunk, output_path, kwargs), error_callback=err_call) + async_results.append(result) + pool.close() + + for ar in async_results: + try: + ar.get(timeout=3600) + except Exception as e: + logger.error(f"Task failed with exception: {e}") + pool.terminate() + raise CompareException(CompareException.MULTIPROCESS_ERROR) from e + pool.join() diff --git a/debug/accuracy_tools/msprobe/mindspore/compare/common_dir_compare.py b/debug/accuracy_tools/msprobe/mindspore/compare/common_dir_compare.py index 4b509cc3327b40bba058fe3eaac20dedb89d5094..e5a12da97e950988fae130b5ddf40791a294e837 100644 --- a/debug/accuracy_tools/msprobe/mindspore/compare/common_dir_compare.py +++ b/debug/accuracy_tools/msprobe/mindspore/compare/common_dir_compare.py @@ -29,7 +29,7 @@ from msprobe.core.common.log import logger from msprobe.core.common.utils import CompareException from msprobe.core.common.exceptions import FileCheckException from msprobe.core.common.file_utils import check_file_or_directory_path, write_df_to_csv, create_directory, \ - check_path_before_create, load_npy + check_path_before_create, load_npy from msprobe.core.common.const import CompareConst from msprobe.core.compare.npy_compare import compare_ops_apply from msprobe.core.compare.multiprocessing_compute import check_accuracy @@ -39,11 +39,11 @@ from msprobe.mindspore.compare.utils import check_name_map_dict def common_dir_compare(input_params: Dict, output_dir: str) -> Optional[pd.DataFrame]: """ 高级目录比对函数,完全镜像输入目录结构 - + Args: input_params: 包含npu_path和bench_path的字典 output_dir: 输出根目录 - + Returns: 当输入目录是平铺npy文件时返回DataFrame,否则返回None """ @@ -69,29 +69,29 @@ def common_dir_compare(input_params: Dict, output_dir: str) -> Optional[pd.DataF def process_directory_pair(item: Tuple[Path, Tuple[Path, Path]], name_map_dict: Dict, output_dir: str): """ 处理一个目录对 - + Args: item: (相对路径, (npu目录, bench目录))元组 output_dir: 输出根目录 - + Returns: 比对结果的DataFrame(仅平铺结构时返回) """ rel_path, (npu_dir, bench_dir) = item - + # 创建镜像输出目录 output_path = Path(output_dir) / rel_path create_directory(output_path) - + # 生成文件映射 npu_files = find_npy_files(npu_dir) bench_files = find_npy_files(bench_dir) map_dict = generate_map_dict(npu_files, bench_files, name_map_dict) - + if not map_dict: logger.warning(f"No file pairs found in {rel_path}") return None - + # 执行比对 result_df = do_multi_process(process_chunk, map_dict) check_path_before_create(output_path) @@ -105,16 +105,16 @@ def process_directory_pair(item: Tuple[Path, Tuple[Path, Path]], name_map_dict: def build_mirror_file_tree(npu_root: Path, bench_root: Path) -> Dict[Path, Tuple[Path, Path]]: """ 构建镜像文件树,键为相对路径,值为(npu_path, bench_path)元组 - + Args: npu_root: NPU数据根目录 bench_root: 基准数据根目录 - + Returns: 文件树字典 """ file_tree = {} - + # 遍历NPU目录构建树结构 # 使用os.walk遍历目录,限制深度为10层 for root, dirs, files in os.walk(npu_root): @@ -123,23 +123,23 @@ def build_mirror_file_tree(npu_root: Path, bench_root: Path) -> Dict[Path, Tuple if depth > 10: dirs.clear() # 清空dirs列表以阻止继续递归 continue - + # 检查当前目录下是否有npy文件 if any(f.endswith('.npy') for f in files): # 获取相对路径 dir_path = Path(root).relative_to(npu_root) npu_dir_pair = os.path.join(npu_root, dir_path) bench_dir_pair = os.path.join(bench_root, dir_path) - + try: check_file_or_directory_path(bench_dir_pair, isdir=True) except FileCheckException: continue - + # 添加到文件树 if dir_path not in file_tree: file_tree[dir_path] = (npu_dir_pair, bench_dir_pair) - + return file_tree @@ -162,13 +162,13 @@ def find_npy_files(directory): file_name = base_name[0] logger.info(f"Generating file info for file: {file}") - + # 使用一致的分割逻辑 file_ele = file_name.split('_') - + if len(file_ele) < 2: continue - + key = '_'.join(file_ele[:-2]) if key: # 文件的完整路径 @@ -212,14 +212,14 @@ def do_multi_process(func, map_dict): df_chunks = [result_df] process_num = 1 logger.info(f"Using {process_num} processes with chunk size {df_chunk_size}") - + # 分割字典 map_chunks = split_dict(map_dict, df_chunk_size) - + # 创建结果列表和进程池 results = [] pool = multiprocessing.Pool(process_num) - + progress_bar = tqdm(total=len(result_df), desc="API/Module Item Compare Process", unit="row", ncols=100) def update_progress(size, progress_lock, extra_param=None): @@ -228,35 +228,30 @@ def do_multi_process(func, map_dict): def err_call(args): logger.error('multiprocess compare failed! Reason: {}'.format(args)) - try: - pool.close() - except OSError as e: - logger.error(f'pool terminate failed: {str(e)}') results = [] + + # 提交任务到进程池 + for process_idx, (df_chunk, map_chunk) in enumerate(zip(df_chunks, map_chunks)): + start_idx = df_chunk_size * process_idx + result = pool.apply_async( + func, + args=(df_chunk, start_idx, map_chunk, lock), + error_callback=err_call, + callback=partial(update_progress, len(map_chunk), lock) + ) + results.append(result) + pool.close() + try: - # 提交任务到进程池 - for process_idx, (df_chunk, map_chunk) in enumerate(zip(df_chunks, map_chunks)): - start_idx = df_chunk_size * process_idx - result = pool.apply_async( - func, - args=(df_chunk, start_idx, map_chunk, lock), - error_callback=err_call, - callback=partial(update_progress, len(map_chunk), lock) - ) - results.append(result) - - final_results = [r.get() for r in results] - # 等待所有任务完成 - pool.close() - pool.join() - return pd.concat(final_results, ignore_index=True) + final_results = [r.get(timeout=3600) for r in results] except Exception as e: - logger.error(f"\nMain process error: {str(e)}") + logger.error(f"Task failed with exception: {e}") pool.terminate() return pd.DataFrame({}) - finally: - pool.close() + # 等待所有任务完成 + pool.join() + return pd.concat(final_results, ignore_index=True) def initialize_result_df(total_size): diff --git a/debug/accuracy_tools/msprobe/pytorch/compare/utils.py b/debug/accuracy_tools/msprobe/pytorch/compare/utils.py index 16473ff386d89de5f3bbb269e69837c07a950ea5..4d58f447d1b30194a74e7aebb38eed8e876a76c7 100644 --- a/debug/accuracy_tools/msprobe/pytorch/compare/utils.py +++ b/debug/accuracy_tools/msprobe/pytorch/compare/utils.py @@ -35,7 +35,8 @@ def read_pt_data(dir_path, file_name): data_value = load_pt(data_path, to_cpu=True).detach() except RuntimeError as e: # 这里捕获 load_pt 中抛出的异常 - logger.error(f"Failed to load the .pt file at {data_path}.") + data_path_file_name = os.path.basename(data_path) + logger.error(f"Failed to load the .pt file at {data_path_file_name}.") raise CompareException(CompareException.INVALID_FILE_ERROR) from e except AttributeError as e: # 这里捕获 detach 方法抛出的异常