From 6c9b5979a4687640b24a457f99be6ac15df4012b Mon Sep 17 00:00:00 2001 From: Linwei-Ying Date: Mon, 30 Jun 2025 16:54:04 +0800 Subject: [PATCH] compare indexerror bugfix --- .../msprobe/core/compare/acc_compare.py | 55 +++++++++++-------- .../msprobe/core/compare/utils.py | 6 ++ 2 files changed, 38 insertions(+), 23 deletions(-) diff --git a/debug/accuracy_tools/msprobe/core/compare/acc_compare.py b/debug/accuracy_tools/msprobe/core/compare/acc_compare.py index 2697566ed3..59c4b42ee2 100644 --- a/debug/accuracy_tools/msprobe/core/compare/acc_compare.py +++ b/debug/accuracy_tools/msprobe/core/compare/acc_compare.py @@ -31,7 +31,7 @@ from msprobe.core.common.utils import CompareException, add_time_with_xlsx, chec set_dump_path, get_dump_mode, check_compare_param, check_configuration_param, load_stack_json, add_time_with_json from msprobe.core.compare.check import check_dump_json_str, check_stack_json_str, cross_dtype_mapping from msprobe.core.compare.utils import merge_tensor, print_compare_ends_info, read_op, \ - reorder_op_x_list, set_stack_json_path + reorder_op_x_list, set_stack_json_path, check_api_info_len from msprobe.core.compare.config import ModeConfig, MappingConfig, MappingDict from msprobe.core.compare.multiprocessing_compute import CompareRealData from msprobe.core.compare.highlight import HighLight @@ -214,29 +214,38 @@ class ParseData: # 遍历单个API的所有item for index, op_name in enumerate(op_name_reorder): result[CompareConst.OP_NAME].append(op_name) - try: - if (CompareConst.INPUT_PATTERN in op_name) or (CompareConst.KWARGS_PATTERN in op_name): - struct = merge_list[CompareConst.INPUT_STRUCT].pop(0) - elif CompareConst.OUTPUT_PATTERN in op_name: - struct = merge_list[CompareConst.OUTPUT_STRUCT].pop(0) - elif CompareConst.PARAMS_PATTERN in op_name: - struct = merge_list[CompareConst.PARAMS_STRUCT].pop(0) - else: - struct = merge_list[CompareConst.PARAMS_GRAD_STRUCT].pop(0) - result[Const.DTYPE].append(struct[0]) - result[Const.SHAPE].append(struct[1]) - if self.mode_config.dump_mode == Const.MD5: - result[Const.MD5].append(struct[2]) - result[Const.SUMMARY].append(summary_reorder.pop(0)) - result[Const.STACK_INFO].append( - merge_list[Const.STACK_INFO][0] if index == 0 and self.mode_config.stack_mode else None) - if self.mode_config.dump_mode == Const.ALL: - result['data_name'].append(data_name_reorder.pop(0)) - except IndexError as e: - logger.error(f'Index out of bounds error, please check info of api: {op_name}.') - raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from e - + if (CompareConst.INPUT_PATTERN in op_name) or (CompareConst.KWARGS_PATTERN in op_name): + info_list = merge_list[CompareConst.INPUT_STRUCT] + elif CompareConst.OUTPUT_PATTERN in op_name: + info_list = merge_list[CompareConst.OUTPUT_STRUCT] + elif CompareConst.PARAMS_PATTERN in op_name: + info_list = merge_list[CompareConst.PARAMS_STRUCT] + elif CompareConst.PARAMS_GRAD_PATTERN in op_name: + info_list = merge_list[CompareConst.PARAMS_GRAD_STRUCT] + else: + info_list = merge_list[CompareConst.DEBUG_STRUCT] + check_api_info_len(op_name, info_list, 1) + struct = info_list.pop(0) + + check_api_info_len(op_name, struct, 2) + result[Const.DTYPE].append(struct[0]) + result[Const.SHAPE].append(struct[1]) + if self.mode_config.dump_mode == Const.MD5: + check_api_info_len(op_name, struct, 3) + result[Const.MD5].append(struct[2]) + + check_api_info_len(op_name, summary_reorder, 1) + result[Const.SUMMARY].append(summary_reorder.pop(0)) + + if index == 0 and self.mode_config.stack_mode: + check_api_info_len(op_name, merge_list[Const.STACK_INFO], 1) + result[Const.STACK_INFO].append(merge_list[Const.STACK_INFO][0]) + else: + result[Const.STACK_INFO].append(None) + if self.mode_config.dump_mode == Const.ALL: + check_api_info_len(op_name, data_name_reorder, 1) + result['data_name'].append(data_name_reorder.pop(0)) progress_bar.update(1) progress_bar.close() return pd.DataFrame(result) diff --git a/debug/accuracy_tools/msprobe/core/compare/utils.py b/debug/accuracy_tools/msprobe/core/compare/utils.py index 6da9f3e4bd..809c17d603 100644 --- a/debug/accuracy_tools/msprobe/core/compare/utils.py +++ b/debug/accuracy_tools/msprobe/core/compare/utils.py @@ -225,6 +225,12 @@ def merge_tensor(tensor_list, dump_mode): return op_dict if op_dict[CompareConst.OP_NAME] else {} +def check_api_info_len(op_name, info_list, len_require): + if len(info_list) < len_require: + logger.error(f'Index out of bounds error, please check info of api: {op_name}.') + raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) + + def print_compare_ends_info(): total_len = len(CompareConst.COMPARE_ENDS_SUCCESSFULLY) + Const.FILL_CHAR_NUMS logger.info('*' * total_len) -- Gitee