diff --git a/debug/accuracy_tools/msprobe/core/compare/acc_compare.py b/debug/accuracy_tools/msprobe/core/compare/acc_compare.py index 77f22b80c3f65e928eedbea452086695f88827c6..bc3741668308af9da18d48d0f51b3e9c784234cf 100644 --- a/debug/accuracy_tools/msprobe/core/compare/acc_compare.py +++ b/debug/accuracy_tools/msprobe/core/compare/acc_compare.py @@ -26,20 +26,23 @@ class Comparator: @staticmethod def get_result_md5_compare(ms_op_name, bench_op_name, npu_ops_all, bench_ops_all, *args): - result_item = [ms_op_name, bench_op_name, npu_ops_all.get(ms_op_name).get('struct')[0], - bench_ops_all.get(bench_op_name).get('struct')[0], - npu_ops_all.get(ms_op_name).get('struct')[1], - bench_ops_all.get(bench_op_name).get('struct')[1], - npu_ops_all.get(ms_op_name).get('struct')[2], - bench_ops_all.get(bench_op_name).get('struct')[2], - CompareConst.PASS if npu_ops_all.get(ms_op_name).get('struct')[2] - == bench_ops_all.get(bench_op_name).get('struct')[2] - else CompareConst.DIFF] - if args[0]: - result_item.extend(args[1]) + npu_struct = npu_ops_all.get(ms_op_name).get('struct', []) + bench_struct = bench_ops_all.get(bench_op_name).get('struct', []) + try: + + result_item = [ms_op_name, bench_op_name, npu_struct[0], bench_struct[0], + npu_struct[1], bench_struct[1], npu_struct[2], bench_struct[2], + CompareConst.PASS if npu_struct[2] == bench_struct[2] else CompareConst.DIFF] + except IndexError as err: + logger.error(f"The length of npu_struct or bench_struct must be >= 3, " + f"but got npu_struct={len(npu_struct)} and bench_struct={len(bench_struct)}. Please check!") + raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from err else: - result_item.append(CompareConst.NONE) - return result_item + if args[0]: + result_item.extend(args[1]) + else: + result_item.append(CompareConst.NONE) + return result_item @staticmethod def calculate_summary_data(npu_summary_data, bench_summary_data, result_item):