From 88df5dfb8369bb770402e5186ef205953308048c Mon Sep 17 00:00:00 2001 From: zhujiaxing Date: Tue, 8 Oct 2024 15:13:28 +0800 Subject: [PATCH] Add a check in get_result_md5_compare function that may raise an out-of-index error. --- .../msprobe/core/compare/acc_compare.py | 29 ++++++++++--------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/debug/accuracy_tools/msprobe/core/compare/acc_compare.py b/debug/accuracy_tools/msprobe/core/compare/acc_compare.py index 77f22b80c..bc3741668 100644 --- a/debug/accuracy_tools/msprobe/core/compare/acc_compare.py +++ b/debug/accuracy_tools/msprobe/core/compare/acc_compare.py @@ -26,20 +26,23 @@ class Comparator: @staticmethod def get_result_md5_compare(ms_op_name, bench_op_name, npu_ops_all, bench_ops_all, *args): - result_item = [ms_op_name, bench_op_name, npu_ops_all.get(ms_op_name).get('struct')[0], - bench_ops_all.get(bench_op_name).get('struct')[0], - npu_ops_all.get(ms_op_name).get('struct')[1], - bench_ops_all.get(bench_op_name).get('struct')[1], - npu_ops_all.get(ms_op_name).get('struct')[2], - bench_ops_all.get(bench_op_name).get('struct')[2], - CompareConst.PASS if npu_ops_all.get(ms_op_name).get('struct')[2] - == bench_ops_all.get(bench_op_name).get('struct')[2] - else CompareConst.DIFF] - if args[0]: - result_item.extend(args[1]) + npu_struct = npu_ops_all.get(ms_op_name).get('struct', []) + bench_struct = bench_ops_all.get(bench_op_name).get('struct', []) + try: + + result_item = [ms_op_name, bench_op_name, npu_struct[0], bench_struct[0], + npu_struct[1], bench_struct[1], npu_struct[2], bench_struct[2], + CompareConst.PASS if npu_struct[2] == bench_struct[2] else CompareConst.DIFF] + except IndexError as err: + logger.error(f"The length of npu_struct or bench_struct must be >= 3, " + f"but got npu_struct={len(npu_struct)} and bench_struct={len(bench_struct)}. Please check!") + raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from err else: - result_item.append(CompareConst.NONE) - return result_item + if args[0]: + result_item.extend(args[1]) + else: + result_item.append(CompareConst.NONE) + return result_item @staticmethod def calculate_summary_data(npu_summary_data, bench_summary_data, result_item): -- Gitee