From 0edc46457db85abfa23cd7a23454b18248b9edb7 Mon Sep 17 00:00:00 2001 From: Linwei-Ying Date: Mon, 22 Jul 2024 15:31:18 +0800 Subject: [PATCH 1/7] md5 compare highlight bugfix --- .../msprobe/core/common/utils.py | 4 ++- .../msprobe/pytorch/compare/acc_compare.py | 25 ++++++++++++++--- .../msprobe/pytorch/compare/highlight.py | 27 ++++++++++++++----- 3 files changed, 46 insertions(+), 10 deletions(-) diff --git a/debug/accuracy_tools/msprobe/core/common/utils.py b/debug/accuracy_tools/msprobe/core/common/utils.py index 32aba8d8af..8be2fe0090 100644 --- a/debug/accuracy_tools/msprobe/core/common/utils.py +++ b/debug/accuracy_tools/msprobe/core/common/utils.py @@ -501,9 +501,11 @@ def task_dumppath_get(input_param): return summary_compare, md5_compare -def get_header_index(header_name, summary_compare=False): +def get_header_index(header_name, summary_compare=False, md5_compare=False): if summary_compare: header = CompareConst.SUMMARY_COMPARE_RESULT_HEADER[:] + elif md5_compare: + header = CompareConst.MD5_COMPARE_RESULT_HEADER[:] else: header = CompareConst.COMPARE_RESULT_HEADER[:] if header_name not in header: diff --git a/debug/accuracy_tools/msprobe/pytorch/compare/acc_compare.py b/debug/accuracy_tools/msprobe/pytorch/compare/acc_compare.py index a4b6884343..af87ea49e2 100644 --- a/debug/accuracy_tools/msprobe/pytorch/compare/acc_compare.py +++ b/debug/accuracy_tools/msprobe/pytorch/compare/acc_compare.py @@ -526,10 +526,27 @@ def handle_inf_nan(n_value, b_value): return n_value, b_value -def find_error_rows(result, last_len, n_num_input, highlight_dict, summary_compare=False, md5_compare=False): +def find_error_rows(result, last_len, n_num_input, highlight_dict, summary_md5_select): """找到单个API中需要高亮的行""" + red_lines, yellow_lines = [], [] + LineInfo = namedtuple('LineInfo', ['line_data', 'num_pointer']) + ApiInfo = namedtuple('ApiInfo', ['api_input', 'api_output', 'num_pointer']) + ColorColumns = namedtuple('ColorColumns', ['red', 'yellow']) + color_columns = ColorColumns(red=red_lines, yellow=yellow_lines) + + summary_compare = summary_md5_select[0] + md5_compare = summary_md5_select[1] + if md5_compare: + for i, line in enumerate(result): + num = last_len + i + line_info = LineInfo(line_data=line, num_pointer=num) + for rule in HighlightRules.md5_compare_rules.values(): + rule.apply(line_info, color_columns, summary_compare, md5_compare) + highlight_dict.get('red_rows', []).extend(list(set(red_lines))) + highlight_dict.get('yellow_rows', []).extend(list(set(yellow_lines) - set(red_lines))) return + npu_max_index = get_header_index('NPU max', summary_compare) bench_max_index = get_header_index('Bench max', summary_compare) max_diff_index = get_header_index('Max diff' if summary_compare else 'MaxAbsErr', summary_compare) @@ -587,6 +604,7 @@ def get_name_and_state(name): def find_compare_result_error_rows(result_df, highlight_dict, summary_compare, md5_compare): """将dataframe根据API分组,并找到有误差的算子用于高亮""" + summary_md5_select = [summary_compare, md5_compare] result = result_df.values start, input_num, output_num, end = 0, 0, 0, len(result_df) last_api_name, last_state = None, None @@ -603,7 +621,7 @@ def find_compare_result_error_rows(result_df, highlight_dict, summary_compare, m else: output_num = num find_error_rows(result[start:start + input_num + output_num], start, input_num, highlight_dict, - summary_compare, md5_compare) + summary_md5_select) num, last_api_name, last_state = 1, api_name, state start += input_num + output_num input_num, output_num = 1, 0 @@ -614,7 +632,8 @@ def find_compare_result_error_rows(result_df, highlight_dict, summary_compare, m input_num = num else: output_num = num - find_error_rows(result[start:start + input_num + output_num], start, input_num, highlight_dict, summary_compare, md5_compare) + find_error_rows(result[start:start + input_num + output_num], start, input_num, highlight_dict, + summary_md5_select) def highlight_rows_xlsx(result_df, highlight_dict, file_path): diff --git a/debug/accuracy_tools/msprobe/pytorch/compare/highlight.py b/debug/accuracy_tools/msprobe/pytorch/compare/highlight.py index 82f0022f8b..2543c136fe 100644 --- a/debug/accuracy_tools/msprobe/pytorch/compare/highlight.py +++ b/debug/accuracy_tools/msprobe/pytorch/compare/highlight.py @@ -7,13 +7,13 @@ from msprobe.core.common.const import CompareConst class HighlightCheck(abc.ABC): @abc.abstractmethod - def apply(self, info, color_columns, summary_compare): + def apply(self, info, color_columns, summary_compare, md5_compare=False): raise NotImplementedError class CheckOrderMagnitude(HighlightCheck): """检查Max diff的数量级差异""" - def apply(self, info, color_columns, summary_compare=True): + def apply(self, info, color_columns, summary_compare, md5_compare=False): api_in, api_out, num = info max_diff_index = get_header_index('Max diff' if summary_compare else 'MaxAbsErr', summary_compare) if abs(api_in[max_diff_index]) > abs(api_out[max_diff_index]): @@ -26,7 +26,7 @@ class CheckOrderMagnitude(HighlightCheck): class CheckOneThousandErrorRatio(HighlightCheck): """检查千分误差比率""" - def apply(self, info, color_columns, summary_compare=True): + def apply(self, info, color_columns, summary_compare, md5_compare=False): api_in, api_out, num = info one_thousand_index = get_header_index('One Thousandth Err Ratio', summary_compare) if not isinstance(api_in[one_thousand_index], (float, int)) or not isinstance(api_out[one_thousand_index], (float, int)): @@ -39,7 +39,7 @@ class CheckOneThousandErrorRatio(HighlightCheck): class CheckCosineSimilarity(HighlightCheck): """检查余弦相似度""" - def apply(self, info, color_columns, summary_compare=True): + def apply(self, info, color_columns, summary_compare, md5_compare=False): api_in, api_out, num = info cosine_index = get_header_index('Cosine', summary_compare) if not isinstance(api_in[cosine_index], (float, int)) or not isinstance(api_out[cosine_index], (float, int)): @@ -50,7 +50,7 @@ class CheckCosineSimilarity(HighlightCheck): class CheckMaxRelativeDiff(HighlightCheck): """检查最大相对差异""" - def apply(self, info, color_columns, summary_compare=True): + def apply(self, info, color_columns, summary_compare, md5_compare=False): api_in, api_out, num = info max_diff_index = get_header_index('Max diff', summary_compare) bench_max_index = get_header_index('Bench max', summary_compare) @@ -65,9 +65,20 @@ class CheckMaxRelativeDiff(HighlightCheck): color_columns.yellow.append(num) +class CheckMd5Diff(HighlightCheck): + """检查md5值差异""" + def apply(self, info, color_columns, summary_compare, md5_compare=False): + line, num = info + npu_md5_index = get_header_index('NPU MD5', summary_compare, md5_compare) + bench_md5_index = get_header_index('BENCH MD5', summary_compare, md5_compare) + if str(line[npu_md5_index]) != str(line[bench_md5_index]): + color_columns.red.append(num) + return + + class CheckOverflow(HighlightCheck): """检查是否存在溢出""" - def apply(self, info, color_columns, summary_compare=True): + def apply(self, info, color_columns, summary_compare, md5_compare=False): line, num = info npu_max_index = get_header_index('NPU max', summary_compare) npu_min_index = get_header_index('NPU min', summary_compare) @@ -98,3 +109,7 @@ class HighlightRules: "check_order_magnitude": CheckOrderMagnitude(), "check_max_relative_diff": CheckMaxRelativeDiff(), } + + md5_compare_rules = { + "check_md5_diff": CheckMd5Diff() + } -- Gitee From 27a72da8efb0c41d93f6a94dddb9550183fc8410 Mon Sep 17 00:00:00 2001 From: Linwei-Ying Date: Mon, 22 Jul 2024 15:51:58 +0800 Subject: [PATCH 2/7] md5 compare highlight bugfix --- .../accuracy_tools/msprobe/pytorch/compare/acc_compare.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/debug/accuracy_tools/msprobe/pytorch/compare/acc_compare.py b/debug/accuracy_tools/msprobe/pytorch/compare/acc_compare.py index af87ea49e2..1f991934a7 100644 --- a/debug/accuracy_tools/msprobe/pytorch/compare/acc_compare.py +++ b/debug/accuracy_tools/msprobe/pytorch/compare/acc_compare.py @@ -537,6 +537,7 @@ def find_error_rows(result, last_len, n_num_input, highlight_dict, summary_md5_s summary_compare = summary_md5_select[0] md5_compare = summary_md5_select[1] + # 对单行API的输入或输出进行md5值差异判断 if md5_compare: for i, line in enumerate(result): num = last_len + i @@ -551,12 +552,6 @@ def find_error_rows(result, last_len, n_num_input, highlight_dict, summary_md5_s bench_max_index = get_header_index('Bench max', summary_compare) max_diff_index = get_header_index('Max diff' if summary_compare else 'MaxAbsErr', summary_compare) - red_lines, yellow_lines = [], [] - LineInfo = namedtuple('LineInfo', ['line_data', 'num_pointer']) - ApiInfo = namedtuple('ApiInfo', ['api_input', 'api_output', 'num_pointer']) - ColorColumns = namedtuple('ColorColumns', ['red', 'yellow']) - color_columns = ColorColumns(red=red_lines, yellow=yellow_lines) - # 对单行API的输入或输出进行误差判断 for i, line in enumerate(result): num = last_len + i -- Gitee From 3ccd72d15380c226e6d9eeeea69c147a1e798ca7 Mon Sep 17 00:00:00 2001 From: Linwei-Ying Date: Tue, 23 Jul 2024 15:21:39 +0800 Subject: [PATCH 3/7] review fix add compare_mode --- .../msprobe/core/common/const.py | 6 + .../msprobe/core/common/utils.py | 23 ++- .../msprobe/pytorch/compare/acc_compare.py | 151 ++++++++++++------ .../msprobe/pytorch/compare/highlight.py | 53 +++--- 4 files changed, 156 insertions(+), 77 deletions(-) diff --git a/debug/accuracy_tools/msprobe/core/common/const.py b/debug/accuracy_tools/msprobe/core/common/const.py index dea829c3ff..61d4cb5b52 100644 --- a/debug/accuracy_tools/msprobe/core/common/const.py +++ b/debug/accuracy_tools/msprobe/core/common/const.py @@ -140,6 +140,12 @@ class CompareConst: NPU_NAME, BENCH_NAME, NPU_DTYPE, BENCH_DTYPE, NPU_SHAPE, BENCH_SHAPE, NPU_MD5, BENCH_MD5, RESULT ] + HEAD_OF_COMPARE_MODE = { + "summary": CompareConst.SUMMARY_COMPARE_RESULT_HEADER[:], + "md5": CompareConst.MD5_COMPARE_RESULT_HEADER[:], + "all": CompareConst.COMPARE_RESULT_HEADER[:] + } + # compare standard HUNDRED_RATIO_THRESHOLD = 0.01 THOUSAND_RATIO_THRESHOLD = 0.001 diff --git a/debug/accuracy_tools/msprobe/core/common/utils.py b/debug/accuracy_tools/msprobe/core/common/utils.py index 8be2fe0090..6d98165f51 100644 --- a/debug/accuracy_tools/msprobe/core/common/utils.py +++ b/debug/accuracy_tools/msprobe/core/common/utils.py @@ -501,13 +501,22 @@ def task_dumppath_get(input_param): return summary_compare, md5_compare -def get_header_index(header_name, summary_compare=False, md5_compare=False): - if summary_compare: - header = CompareConst.SUMMARY_COMPARE_RESULT_HEADER[:] - elif md5_compare: - header = CompareConst.MD5_COMPARE_RESULT_HEADER[:] - else: - header = CompareConst.COMPARE_RESULT_HEADER[:] +# HEAD_OF_COMPARE_MODE = { +# "summary": CompareConst.SUMMARY_COMPARE_RESULT_HEADER[:], +# "md5": CompareConst.MD5_COMPARE_RESULT_HEADER[:], +# "all": CompareConst.COMPARE_RESULT_HEADER[:] +# } + + +# def get_header_index(header_name, summary_compare=False, md5_compare=False): +def get_header_index(header_name, compare_mode): + header = CompareConst.HEAD_OF_COMPARE_MODE.get(compare_mode) + # if summary_compare: + # header = CompareConst.SUMMARY_COMPARE_RESULT_HEADER[:] + # elif md5_compare: + # header = CompareConst.MD5_COMPARE_RESULT_HEADER[:] + # else: + # header = CompareConst.COMPARE_RESULT_HEADER[:] if header_name not in header: logger.error(f"{header_name} not in data name") raise CompareException(CompareException.INVALID_PARAM_ERROR) diff --git a/debug/accuracy_tools/msprobe/pytorch/compare/acc_compare.py b/debug/accuracy_tools/msprobe/pytorch/compare/acc_compare.py index 1f991934a7..b011cacc19 100644 --- a/debug/accuracy_tools/msprobe/pytorch/compare/acc_compare.py +++ b/debug/accuracy_tools/msprobe/pytorch/compare/acc_compare.py @@ -130,7 +130,8 @@ def rename_api(npu_name, process): return torch_func -def merge_tensor(tensor_list, summary_compare, md5_compare): +# def merge_tensor(tensor_list, summary_compare, md5_compare): +def merge_tensor(tensor_list, compare_mode): op_dict = {} op_dict["op_name"] = [] op_dict["input_struct"] = [] @@ -139,8 +140,9 @@ def merge_tensor(tensor_list, summary_compare, md5_compare): op_dict["summary"] = [] op_dict["stack_info"] = [] - all_mode_bool = not (summary_compare or md5_compare) - if all_mode_bool: + # all_mode_bool = not (summary_compare or md5_compare) + # if all_mode_bool: + if compare_mode == Const.ALL: op_dict["data_name"] = [] for tensor in tensor_list: @@ -148,7 +150,8 @@ def merge_tensor(tensor_list, summary_compare, md5_compare): op_dict['stack_info'].append(tensor['full_info']) break op_dict["op_name"].append(tensor['full_op_name']) - if not md5_compare: + # if not md5_compare: + if compare_mode != Const.MD5: if tensor['full_op_name'].find("input") != -1: op_dict["input_struct"].append((tensor['dtype'], tensor['shape'])) elif tensor['full_op_name'].find("kwarg") != -1: @@ -185,15 +188,17 @@ def match_op(npu_queue, bench_queue, fuzzy_match): return -1, -1 -def get_accuracy(result, n_dict, b_dict, summary_compare=False, md5_compare=False): +# def get_accuracy(result, n_dict, b_dict, summary_compare=False, md5_compare=False): +def get_accuracy(result, n_dict, b_dict, compare_mode): def get_accuracy_core(n_start, n_len, b_start, b_len, key): min_len = min(n_len, b_len) npu_stack_info = n_dict.get("stack_info", None) bench_stack_info = b_dict.get("stack_info", None) has_stack = npu_stack_info and bench_stack_info - all_mode_bool = not (summary_compare or md5_compare) - if all_mode_bool: + # all_mode_bool = not (summary_compare or md5_compare) + # if all_mode_bool: + if compare_mode == Const.ALL: npu_data_name = n_dict.get("data_name", None) bench_data_name = b_dict.get("data_name", None) @@ -204,7 +209,8 @@ def get_accuracy(result, n_dict, b_dict, summary_compare=False, md5_compare=Fals n_struct = n_dict[key][index] b_struct = b_dict[key][index] err_msg = "" - if md5_compare: + # if md5_compare: + if compare_mode == Const.MD5: result_item = [n_name, b_name, n_struct[0], b_struct[0], n_struct[1], b_struct[1], n_struct[2], b_struct[2], CompareConst.PASS if n_struct[2] == b_struct[2] else CompareConst.DIFF] @@ -215,7 +221,8 @@ def get_accuracy(result, n_dict, b_dict, summary_compare=False, md5_compare=Fals result.append(result_item) continue - if summary_compare: + # if summary_compare: + if compare_mode == Const.SUMMARY: result_item = [n_name, b_name, n_struct[0], b_struct[0], n_struct[1], b_struct[1], " ", " ", " ", " ", " ", " ", " ", " "] else: @@ -227,7 +234,8 @@ def get_accuracy(result, n_dict, b_dict, summary_compare=False, md5_compare=Fals bench_summary_data = b_dict.get("summary")[b_start + index] result_item.extend(bench_summary_data) - if summary_compare: + # if summary_compare: + if compare_mode == Const.SUMMARY: start_idx = CompareConst.SUMMARY_COMPARE_RESULT_HEADER.index(CompareConst.MAX_DIFF) warning_flag = False for i, (npu_val, bench_val) in enumerate(zip(npu_summary_data, bench_summary_data)): @@ -250,7 +258,8 @@ def get_accuracy(result, n_dict, b_dict, summary_compare=False, md5_compare=Fals if str(result_item[i]) in ('inf', '-inf', 'nan'): result_item[i] = f'{result_item[i]}\t' - result_item.append(accuracy_check if summary_compare else CompareConst.ACCURACY_CHECK_YES) + # result_item.append(accuracy_check if summary_compare else CompareConst.ACCURACY_CHECK_YES) + result_item.append(accuracy_check if compare_mode == Const.SUMMARY else CompareConst.ACCURACY_CHECK_YES) result_item.append(err_msg) if has_stack and index == 0 and key == "input_struct": result_item.extend(npu_stack_info) @@ -265,7 +274,8 @@ def get_accuracy(result, n_dict, b_dict, summary_compare=False, md5_compare=Fals for index in range(b_len, n_len): n_name = n_dict['op_name'][n_start + index] n_struct = n_dict[key][index] - if md5_compare: + # if md5_compare: + if compare_mode == Const.MD5: result_item = [n_name, CompareConst.NAN, n_struct[0], CompareConst.NAN, n_struct[1], CompareConst.NAN, n_struct[2], CompareConst.NAN, CompareConst.NAN] result.append(result_item) @@ -526,7 +536,8 @@ def handle_inf_nan(n_value, b_value): return n_value, b_value -def find_error_rows(result, last_len, n_num_input, highlight_dict, summary_md5_select): +# def find_error_rows(result, last_len, n_num_input, highlight_dict, summary_md5_select): +def find_error_rows(result, last_len, n_num_input, highlight_dict, compare_mode): """找到单个API中需要高亮的行""" red_lines, yellow_lines = [], [] LineInfo = namedtuple('LineInfo', ['line_data', 'num_pointer']) @@ -534,11 +545,12 @@ def find_error_rows(result, last_len, n_num_input, highlight_dict, summary_md5_s ColorColumns = namedtuple('ColorColumns', ['red', 'yellow']) color_columns = ColorColumns(red=red_lines, yellow=yellow_lines) - summary_compare = summary_md5_select[0] - md5_compare = summary_md5_select[1] + # summary_compare = summary_md5_select[0] + # md5_compare = summary_md5_select[1] # 对单行API的输入或输出进行md5值差异判断 - if md5_compare: + # if md5_compare: + if compare_mode == Const.MD5: for i, line in enumerate(result): num = last_len + i line_info = LineInfo(line_data=line, num_pointer=num) @@ -548,16 +560,20 @@ def find_error_rows(result, last_len, n_num_input, highlight_dict, summary_md5_s highlight_dict.get('yellow_rows', []).extend(list(set(yellow_lines) - set(red_lines))) return - npu_max_index = get_header_index('NPU max', summary_compare) - bench_max_index = get_header_index('Bench max', summary_compare) - max_diff_index = get_header_index('Max diff' if summary_compare else 'MaxAbsErr', summary_compare) + # npu_max_index = get_header_index('NPU max', summary_compare) + # bench_max_index = get_header_index('Bench max', summary_compare) + # max_diff_index = get_header_index('Max diff' if summary_compare else 'MaxAbsErr', summary_compare) + npu_max_index = get_header_index('NPU max', compare_mode) + bench_max_index = get_header_index('Bench max', compare_mode) + max_diff_index = get_header_index('Max diff' if compare_mode == Const.SUMMARY else 'MaxAbsErr', compare_mode) # 对单行API的输入或输出进行误差判断 for i, line in enumerate(result): num = last_len + i line_info = LineInfo(line_data=line, num_pointer=num) for rule in HighlightRules.basic_rules.values(): - rule.apply(line_info, color_columns, summary_compare) + # rule.apply(line_info, color_columns, summary_compare) + rule.apply(line_info, color_columns, compare_mode) # 对API的输出与输入比较,进行误差判断 for n, api_out in enumerate(result[n_num_input:len(result)]): @@ -575,12 +591,15 @@ def find_error_rows(result, last_len, n_num_input, highlight_dict, summary_md5_s continue api_info = ApiInfo(api_input=api_in, api_output=api_out, num_pointer=num) - if summary_compare: + # if summary_compare: + if compare_mode == Const.SUMMARY: for rule in HighlightRules.summary_compare_rules.values(): - rule.apply(api_info, color_columns, summary_compare) + # rule.apply(api_info, color_columns, summary_compare) + rule.apply(api_info, color_columns, compare_mode) else: for rule in HighlightRules.compare_rules.values(): - rule.apply(api_info, color_columns, summary_compare) + # rule.apply(api_info, color_columns, summary_compare) + rule.apply(api_info, color_columns, compare_mode) highlight_dict.get('red_rows', []).extend(list(set(red_lines))) highlight_dict.get('yellow_rows', []).extend(list(set(yellow_lines) - set(red_lines))) @@ -597,9 +616,10 @@ def get_name_and_state(name): return api_name, state -def find_compare_result_error_rows(result_df, highlight_dict, summary_compare, md5_compare): +# def find_compare_result_error_rows(result_df, highlight_dict, summary_compare, md5_compare): +def find_compare_result_error_rows(result_df, highlight_dict, compare_mode): """将dataframe根据API分组,并找到有误差的算子用于高亮""" - summary_md5_select = [summary_compare, md5_compare] + # summary_md5_select = [summary_compare, md5_compare] result = result_df.values start, input_num, output_num, end = 0, 0, 0, len(result_df) last_api_name, last_state = None, None @@ -615,8 +635,10 @@ def find_compare_result_error_rows(result_df, highlight_dict, summary_compare, m num, last_state = 1, state else: output_num = num + # find_error_rows(result[start:start + input_num + output_num], start, input_num, highlight_dict, + # summary_md5_select) find_error_rows(result[start:start + input_num + output_num], start, input_num, highlight_dict, - summary_md5_select) + compare_mode) num, last_api_name, last_state = 1, api_name, state start += input_num + output_num input_num, output_num = 1, 0 @@ -627,8 +649,10 @@ def find_compare_result_error_rows(result_df, highlight_dict, summary_compare, m input_num = num else: output_num = num + # find_error_rows(result[start:start + input_num + output_num], start, input_num, highlight_dict, + # summary_md5_select) find_error_rows(result[start:start + input_num + output_num], start, input_num, highlight_dict, - summary_md5_select) + compare_mode) def highlight_rows_xlsx(result_df, highlight_dict, file_path): @@ -698,6 +722,12 @@ def compare_core(input_parma, output_path, **kwargs): fuzzy_match = kwargs.get('fuzzy_match', False) summary_compare = kwargs.get('summary_compare', False) md5_compare = kwargs.get('md5_compare', False) + if summary_compare: + compare_mode = Const.SUMMARY + elif md5_compare: + compare_mode = Const.MD5 + else: + compare_mode = Const.ALL logger.info("Please check whether the input data belongs to you. If not, there may be security risks.") file_name = add_time_with_xlsx("compare_result" + suffix) @@ -708,12 +738,16 @@ def compare_core(input_parma, output_path, **kwargs): with FileOpen(input_parma.get("npu_json_path"), "r") as npu_json, \ FileOpen(input_parma.get("bench_json_path"), "r") as bench_json, \ FileOpen(input_parma.get("stack_json_path"), "r") as stack_json: + # result_df = compare_process([npu_json, bench_json, stack_json], stack_mode, fuzzy_match, + # summary_compare, md5_compare) result_df = compare_process([npu_json, bench_json, stack_json], stack_mode, fuzzy_match, - summary_compare, md5_compare) + compare_mode) - if not md5_compare and not summary_compare: + # if not md5_compare and not summary_compare: + if compare_mode == Const.ALL: result_df = _do_multi_process(input_parma, result_df) - find_compare_result_error_rows(result_df, highlight_dict, summary_compare, md5_compare) + # find_compare_result_error_rows(result_df, highlight_dict, summary_compare, md5_compare) + find_compare_result_error_rows(result_df, highlight_dict, compare_mode) highlight_rows_xlsx(result_df, highlight_dict, file_path) if auto_analyze: advisor = Advisor(result_df, output_path) @@ -891,7 +925,8 @@ def read_op(op_data, op_name): return op_parsed_list -def compare_process(file_handles, stack_mode, fuzzy_match, summary_compare=False, md5_compare=False): +# def compare_process(file_handles, stack_mode, fuzzy_match, summary_compare=False, md5_compare=False): +def compare_process(file_handles, stack_mode, fuzzy_match, compare_mode): npu_json_handle, bench_json_handle, stack_json_handle = file_handles npu_json_data = json.load(npu_json_handle) bench_json_data = json.load(bench_json_handle) @@ -926,7 +961,8 @@ def compare_process(file_handles, stack_mode, fuzzy_match, summary_compare=False else: npu_op_parsed_list.append({'full_op_name': op_name_npu, 'full_info': None}) - npu_merge_list = merge_tensor(npu_op_parsed_list, summary_compare, md5_compare) + # npu_merge_list = merge_tensor(npu_op_parsed_list, summary_compare, md5_compare) + npu_merge_list = merge_tensor(npu_op_parsed_list, compare_mode) if npu_merge_list: npu_ops_queue.append(npu_merge_list) except StopIteration: @@ -943,7 +979,8 @@ def compare_process(file_handles, stack_mode, fuzzy_match, summary_compare=False else: bench_op_parsed_list.append({'full_op_name': op_name_bench, 'full_info': None}) - bench_merge_list = merge_tensor(bench_op_parsed_list, summary_compare, md5_compare) + # bench_merge_list = merge_tensor(bench_op_parsed_list, summary_compare, md5_compare) + bench_merge_list = merge_tensor(bench_op_parsed_list, compare_mode) if bench_merge_list: bench_ops_queue.append(bench_merge_list) except StopIteration: @@ -962,31 +999,37 @@ def compare_process(file_handles, stack_mode, fuzzy_match, summary_compare=False b_match_data = bench_ops_queue[b_match_point] un_match_data = npu_ops_queue[0: n_match_point] for npu_data in un_match_data: - get_un_match_accuracy(result, npu_data, md5_compare, summary_compare) - get_accuracy(result, n_match_data, b_match_data, summary_compare, md5_compare) + # get_un_match_accuracy(result, npu_data, md5_compare, summary_compare) + get_un_match_accuracy(result, npu_data, compare_mode) + # get_accuracy(result, n_match_data, b_match_data, summary_compare, md5_compare) + get_accuracy(result, n_match_data, b_match_data, compare_mode) del npu_ops_queue[0: n_match_point + 1] del bench_ops_queue[0: b_match_point + 1] if npu_ops_queue: for npu_data in npu_ops_queue: - get_un_match_accuracy(result, npu_data, md5_compare, summary_compare) - - header = [] - if md5_compare: - header = CompareConst.MD5_COMPARE_RESULT_HEADER[:] - elif summary_compare: - header = CompareConst.SUMMARY_COMPARE_RESULT_HEADER[:] - else: - header = CompareConst.COMPARE_RESULT_HEADER[:] - - all_mode_bool = not (summary_compare or md5_compare) + # get_un_match_accuracy(result, npu_data, md5_compare, summary_compare) + get_un_match_accuracy(result, npu_data, compare_mode) + + # header = [] + # if md5_compare: + # header = CompareConst.MD5_COMPARE_RESULT_HEADER[:] + # elif summary_compare: + # header = CompareConst.SUMMARY_COMPARE_RESULT_HEADER[:] + # else: + # header = CompareConst.COMPARE_RESULT_HEADER[:] + header = CompareConst.HEAD_OF_COMPARE_MODE.get(compare_mode) + + # all_mode_bool = not (summary_compare or md5_compare) if stack_mode: - if all_mode_bool: + # if all_mode_bool: + if compare_mode == Const.ALL: header.append(CompareConst.STACK) header.append(CompareConst.DATA_NAME) else: header.append(CompareConst.STACK) else: - if all_mode_bool: + # if all_mode_bool: + if compare_mode == Const.ALL: for row in result: del row[-2] header.append(CompareConst.DATA_NAME) @@ -998,7 +1041,8 @@ def compare_process(file_handles, stack_mode, fuzzy_match, summary_compare=False return result_df -def get_un_match_accuracy(result, n_dict, md5_compare, summary_compare): +# def get_un_match_accuracy(result, n_dict, md5_compare, summary_compare): +def get_un_match_accuracy(result, n_dict, compare_mode): index_out = 0 npu_stack_info = n_dict.get("stack_info", None) bench_name, bench_type, bench_shape = CompareConst.NAN, CompareConst.NAN, CompareConst.NAN @@ -1012,13 +1056,15 @@ def get_un_match_accuracy(result, n_dict, md5_compare, summary_compare): index_out += 1 result_item = [n_name, bench_name, n_struct[0], bench_type, n_struct[1], bench_shape] - if md5_compare: + # if md5_compare: + if compare_mode == Const.MD5: result_item.extend([CompareConst.NAN] * 3) if npu_stack_info and index == 0: result_item.extend(npu_stack_info) result.append(result_item) continue - if summary_compare: + # if summary_compare: + if compare_mode == Const.SUMMARY: result_item.extend([CompareConst.NAN] * 8) else: result_item.extend([CompareConst.NAN] * 5) @@ -1030,7 +1076,8 @@ def get_un_match_accuracy(result, n_dict, md5_compare, summary_compare): result_item.append(err_msg) if npu_stack_info and index == 0: result_item.extend(npu_stack_info) - if not md5_compare and not summary_compare and result_item[1] == CompareConst.NAN: + # if not md5_compare and not summary_compare and result_item[1] == CompareConst.NAN: + if compare_mode == Const.ALL and result_item[1] == CompareConst.NAN: if index == 0: result_item.extend(["-1"]) else: diff --git a/debug/accuracy_tools/msprobe/pytorch/compare/highlight.py b/debug/accuracy_tools/msprobe/pytorch/compare/highlight.py index 2543c136fe..2dc15c1a45 100644 --- a/debug/accuracy_tools/msprobe/pytorch/compare/highlight.py +++ b/debug/accuracy_tools/msprobe/pytorch/compare/highlight.py @@ -2,20 +2,23 @@ import math import abc import numpy as np from msprobe.core.common.utils import get_header_index -from msprobe.core.common.const import CompareConst +from msprobe.core.common.const import CompareConst, Const class HighlightCheck(abc.ABC): @abc.abstractmethod - def apply(self, info, color_columns, summary_compare, md5_compare=False): + # def apply(self, info, color_columns, summary_compare, md5_compare=False): + def apply(self, info, color_columns, compare_mode): raise NotImplementedError class CheckOrderMagnitude(HighlightCheck): """检查Max diff的数量级差异""" - def apply(self, info, color_columns, summary_compare, md5_compare=False): + # def apply(self, info, color_columns, summary_compare, md5_compare=False): + def apply(self, info, color_columns, compare_mode): api_in, api_out, num = info - max_diff_index = get_header_index('Max diff' if summary_compare else 'MaxAbsErr', summary_compare) + # max_diff_index = get_header_index('Max diff' if summary_compare else 'MaxAbsErr', summary_compare) + max_diff_index = get_header_index('Max diff' if compare_mode == Const.SUMMARY else 'MaxAbsErr', compare_mode) if abs(api_in[max_diff_index]) > abs(api_out[max_diff_index]): return in_order = 0 if abs(api_in[max_diff_index]) < 1 else math.log10(abs(api_in[max_diff_index])) @@ -26,9 +29,11 @@ class CheckOrderMagnitude(HighlightCheck): class CheckOneThousandErrorRatio(HighlightCheck): """检查千分误差比率""" - def apply(self, info, color_columns, summary_compare, md5_compare=False): + # def apply(self, info, color_columns, summary_compare, md5_compare=False): + def apply(self, info, color_columns, compare_mode): api_in, api_out, num = info - one_thousand_index = get_header_index('One Thousandth Err Ratio', summary_compare) + # one_thousand_index = get_header_index('One Thousandth Err Ratio', summary_compare) + one_thousand_index = get_header_index('One Thousandth Err Ratio', compare_mode) if not isinstance(api_in[one_thousand_index], (float, int)) or not isinstance(api_out[one_thousand_index], (float, int)): return if api_in[one_thousand_index] > CompareConst.ONE_THOUSAND_ERROR_IN_RED and api_out[one_thousand_index] < CompareConst.ONE_THOUSAND_ERROR_OUT_RED: @@ -39,9 +44,11 @@ class CheckOneThousandErrorRatio(HighlightCheck): class CheckCosineSimilarity(HighlightCheck): """检查余弦相似度""" - def apply(self, info, color_columns, summary_compare, md5_compare=False): + # def apply(self, info, color_columns, summary_compare, md5_compare=False): + def apply(self, info, color_columns, compare_mode): api_in, api_out, num = info - cosine_index = get_header_index('Cosine', summary_compare) + # cosine_index = get_header_index('Cosine', summary_compare) + cosine_index = get_header_index('Cosine', compare_mode) if not isinstance(api_in[cosine_index], (float, int)) or not isinstance(api_out[cosine_index], (float, int)): return if api_in[cosine_index] - api_out[cosine_index] > CompareConst.COSINE_DIFF_YELLOW: @@ -50,10 +57,13 @@ class CheckCosineSimilarity(HighlightCheck): class CheckMaxRelativeDiff(HighlightCheck): """检查最大相对差异""" - def apply(self, info, color_columns, summary_compare, md5_compare=False): + # def apply(self, info, color_columns, summary_compare, md5_compare=False): + def apply(self, info, color_columns, compare_mode): api_in, api_out, num = info - max_diff_index = get_header_index('Max diff', summary_compare) - bench_max_index = get_header_index('Bench max', summary_compare) + # max_diff_index = get_header_index('Max diff', summary_compare) + # bench_max_index = get_header_index('Bench max', summary_compare) + max_diff_index = get_header_index('Max diff', compare_mode) + bench_max_index = get_header_index('Bench max', compare_mode) input_max_relative_diff = np.abs(np.divide(api_in[max_diff_index], max(0.01, api_in[bench_max_index]))) output_max_relative_diff = np.abs(np.divide(api_out[max_diff_index], max(0.01, api_out[bench_max_index]))) if not isinstance(input_max_relative_diff, (float, int)) or not isinstance(output_max_relative_diff, @@ -67,10 +77,13 @@ class CheckMaxRelativeDiff(HighlightCheck): class CheckMd5Diff(HighlightCheck): """检查md5值差异""" - def apply(self, info, color_columns, summary_compare, md5_compare=False): + # def apply(self, info, color_columns, summary_compare, md5_compare=False): + def apply(self, info, color_columns, compare_mode): line, num = info - npu_md5_index = get_header_index('NPU MD5', summary_compare, md5_compare) - bench_md5_index = get_header_index('BENCH MD5', summary_compare, md5_compare) + # npu_md5_index = get_header_index('NPU MD5', summary_compare, md5_compare) + # bench_md5_index = get_header_index('BENCH MD5', summary_compare, md5_compare) + npu_md5_index = get_header_index('NPU MD5', compare_mode) + bench_md5_index = get_header_index('BENCH MD5', compare_mode) if str(line[npu_md5_index]) != str(line[bench_md5_index]): color_columns.red.append(num) return @@ -78,11 +91,15 @@ class CheckMd5Diff(HighlightCheck): class CheckOverflow(HighlightCheck): """检查是否存在溢出""" - def apply(self, info, color_columns, summary_compare, md5_compare=False): + # def apply(self, info, color_columns, summary_compare, md5_compare=False): + def apply(self, info, color_columns, compare_mode): line, num = info - npu_max_index = get_header_index('NPU max', summary_compare) - npu_min_index = get_header_index('NPU min', summary_compare) - max_diff_index = get_header_index('Max diff' if summary_compare else 'MaxAbsErr', summary_compare) + # npu_max_index = get_header_index('NPU max', summary_compare) + # npu_min_index = get_header_index('NPU min', summary_compare) + # max_diff_index = get_header_index('Max diff' if summary_compare else 'MaxAbsErr', summary_compare) + npu_max_index = get_header_index('NPU max', compare_mode) + npu_min_index = get_header_index('NPU min', compare_mode) + max_diff_index = get_header_index('Max diff' if compare_mode == Const.SUMMARY else 'MaxAbsErr', compare_mode) if str(line[npu_max_index]) in CompareConst.OVERFLOW_LIST or str( line[npu_min_index]) in CompareConst.OVERFLOW_LIST: color_columns.red.append(num) -- Gitee From 201a6f56f56fe355e2505da1b9d0bd95ce114624 Mon Sep 17 00:00:00 2001 From: Linwei-Ying Date: Tue, 23 Jul 2024 16:02:19 +0800 Subject: [PATCH 4/7] review fix add compare_mode --- .../msprobe/core/common/const.py | 6 +- .../msprobe/core/common/utils.py | 14 ----- .../msprobe/pytorch/compare/acc_compare.py | 61 +------------------ .../msprobe/pytorch/compare/highlight.py | 17 ------ 4 files changed, 6 insertions(+), 92 deletions(-) diff --git a/debug/accuracy_tools/msprobe/core/common/const.py b/debug/accuracy_tools/msprobe/core/common/const.py index 61d4cb5b52..08e017e6ea 100644 --- a/debug/accuracy_tools/msprobe/core/common/const.py +++ b/debug/accuracy_tools/msprobe/core/common/const.py @@ -141,9 +141,9 @@ class CompareConst: ] HEAD_OF_COMPARE_MODE = { - "summary": CompareConst.SUMMARY_COMPARE_RESULT_HEADER[:], - "md5": CompareConst.MD5_COMPARE_RESULT_HEADER[:], - "all": CompareConst.COMPARE_RESULT_HEADER[:] + "summary": SUMMARY_COMPARE_RESULT_HEADER[:], + "md5": MD5_COMPARE_RESULT_HEADER[:], + "all": COMPARE_RESULT_HEADER[:] } # compare standard diff --git a/debug/accuracy_tools/msprobe/core/common/utils.py b/debug/accuracy_tools/msprobe/core/common/utils.py index 6d98165f51..5dd94e69f9 100644 --- a/debug/accuracy_tools/msprobe/core/common/utils.py +++ b/debug/accuracy_tools/msprobe/core/common/utils.py @@ -501,22 +501,8 @@ def task_dumppath_get(input_param): return summary_compare, md5_compare -# HEAD_OF_COMPARE_MODE = { -# "summary": CompareConst.SUMMARY_COMPARE_RESULT_HEADER[:], -# "md5": CompareConst.MD5_COMPARE_RESULT_HEADER[:], -# "all": CompareConst.COMPARE_RESULT_HEADER[:] -# } - - -# def get_header_index(header_name, summary_compare=False, md5_compare=False): def get_header_index(header_name, compare_mode): header = CompareConst.HEAD_OF_COMPARE_MODE.get(compare_mode) - # if summary_compare: - # header = CompareConst.SUMMARY_COMPARE_RESULT_HEADER[:] - # elif md5_compare: - # header = CompareConst.MD5_COMPARE_RESULT_HEADER[:] - # else: - # header = CompareConst.COMPARE_RESULT_HEADER[:] if header_name not in header: logger.error(f"{header_name} not in data name") raise CompareException(CompareException.INVALID_PARAM_ERROR) diff --git a/debug/accuracy_tools/msprobe/pytorch/compare/acc_compare.py b/debug/accuracy_tools/msprobe/pytorch/compare/acc_compare.py index b011cacc19..fc1ce1906f 100644 --- a/debug/accuracy_tools/msprobe/pytorch/compare/acc_compare.py +++ b/debug/accuracy_tools/msprobe/pytorch/compare/acc_compare.py @@ -130,7 +130,6 @@ def rename_api(npu_name, process): return torch_func -# def merge_tensor(tensor_list, summary_compare, md5_compare): def merge_tensor(tensor_list, compare_mode): op_dict = {} op_dict["op_name"] = [] @@ -140,8 +139,6 @@ def merge_tensor(tensor_list, compare_mode): op_dict["summary"] = [] op_dict["stack_info"] = [] - # all_mode_bool = not (summary_compare or md5_compare) - # if all_mode_bool: if compare_mode == Const.ALL: op_dict["data_name"] = [] @@ -150,7 +147,6 @@ def merge_tensor(tensor_list, compare_mode): op_dict['stack_info'].append(tensor['full_info']) break op_dict["op_name"].append(tensor['full_op_name']) - # if not md5_compare: if compare_mode != Const.MD5: if tensor['full_op_name'].find("input") != -1: op_dict["input_struct"].append((tensor['dtype'], tensor['shape'])) @@ -168,7 +164,7 @@ def merge_tensor(tensor_list, compare_mode): op_dict["summary"].append([tensor['Max'], tensor['Min'], tensor['Mean'], tensor['Norm']]) - if all_mode_bool: + if compare_mode == Const.ALL: op_dict["data_name"].append(tensor['data_name']) if not op_dict["kwargs_struct"]: @@ -196,20 +192,16 @@ def get_accuracy(result, n_dict, b_dict, compare_mode): bench_stack_info = b_dict.get("stack_info", None) has_stack = npu_stack_info and bench_stack_info - # all_mode_bool = not (summary_compare or md5_compare) - # if all_mode_bool: if compare_mode == Const.ALL: npu_data_name = n_dict.get("data_name", None) bench_data_name = b_dict.get("data_name", None) for index in range(min_len): - n_name = n_dict['op_name'][n_start + index] b_name = b_dict['op_name'][b_start + index] n_struct = n_dict[key][index] b_struct = b_dict[key][index] err_msg = "" - # if md5_compare: if compare_mode == Const.MD5: result_item = [n_name, b_name, n_struct[0], b_struct[0], n_struct[1], b_struct[1], n_struct[2], b_struct[2], @@ -221,7 +213,6 @@ def get_accuracy(result, n_dict, b_dict, compare_mode): result.append(result_item) continue - # if summary_compare: if compare_mode == Const.SUMMARY: result_item = [n_name, b_name, n_struct[0], b_struct[0], n_struct[1], b_struct[1], " ", " ", " ", " ", " ", " ", " ", " "] @@ -234,7 +225,6 @@ def get_accuracy(result, n_dict, b_dict, compare_mode): bench_summary_data = b_dict.get("summary")[b_start + index] result_item.extend(bench_summary_data) - # if summary_compare: if compare_mode == Const.SUMMARY: start_idx = CompareConst.SUMMARY_COMPARE_RESULT_HEADER.index(CompareConst.MAX_DIFF) warning_flag = False @@ -258,14 +248,13 @@ def get_accuracy(result, n_dict, b_dict, compare_mode): if str(result_item[i]) in ('inf', '-inf', 'nan'): result_item[i] = f'{result_item[i]}\t' - # result_item.append(accuracy_check if summary_compare else CompareConst.ACCURACY_CHECK_YES) result_item.append(accuracy_check if compare_mode == Const.SUMMARY else CompareConst.ACCURACY_CHECK_YES) result_item.append(err_msg) if has_stack and index == 0 and key == "input_struct": result_item.extend(npu_stack_info) else: result_item.append(CompareConst.NONE) - if all_mode_bool: + if compare_mode == Const.ALL: result_item.append(npu_data_name[n_start + index]) result.append(result_item) @@ -274,7 +263,6 @@ def get_accuracy(result, n_dict, b_dict, compare_mode): for index in range(b_len, n_len): n_name = n_dict['op_name'][n_start + index] n_struct = n_dict[key][index] - # if md5_compare: if compare_mode == Const.MD5: result_item = [n_name, CompareConst.NAN, n_struct[0], CompareConst.NAN, n_struct[1], CompareConst.NAN, n_struct[2], CompareConst.NAN, CompareConst.NAN] @@ -536,7 +524,6 @@ def handle_inf_nan(n_value, b_value): return n_value, b_value -# def find_error_rows(result, last_len, n_num_input, highlight_dict, summary_md5_select): def find_error_rows(result, last_len, n_num_input, highlight_dict, compare_mode): """找到单个API中需要高亮的行""" red_lines, yellow_lines = [], [] @@ -545,24 +532,17 @@ def find_error_rows(result, last_len, n_num_input, highlight_dict, compare_mode) ColorColumns = namedtuple('ColorColumns', ['red', 'yellow']) color_columns = ColorColumns(red=red_lines, yellow=yellow_lines) - # summary_compare = summary_md5_select[0] - # md5_compare = summary_md5_select[1] - # 对单行API的输入或输出进行md5值差异判断 - # if md5_compare: if compare_mode == Const.MD5: for i, line in enumerate(result): num = last_len + i line_info = LineInfo(line_data=line, num_pointer=num) for rule in HighlightRules.md5_compare_rules.values(): - rule.apply(line_info, color_columns, summary_compare, md5_compare) + rule.apply(line_info, color_columns, compare_mode) highlight_dict.get('red_rows', []).extend(list(set(red_lines))) highlight_dict.get('yellow_rows', []).extend(list(set(yellow_lines) - set(red_lines))) return - # npu_max_index = get_header_index('NPU max', summary_compare) - # bench_max_index = get_header_index('Bench max', summary_compare) - # max_diff_index = get_header_index('Max diff' if summary_compare else 'MaxAbsErr', summary_compare) npu_max_index = get_header_index('NPU max', compare_mode) bench_max_index = get_header_index('Bench max', compare_mode) max_diff_index = get_header_index('Max diff' if compare_mode == Const.SUMMARY else 'MaxAbsErr', compare_mode) @@ -572,7 +552,6 @@ def find_error_rows(result, last_len, n_num_input, highlight_dict, compare_mode) num = last_len + i line_info = LineInfo(line_data=line, num_pointer=num) for rule in HighlightRules.basic_rules.values(): - # rule.apply(line_info, color_columns, summary_compare) rule.apply(line_info, color_columns, compare_mode) # 对API的输出与输入比较,进行误差判断 @@ -591,14 +570,11 @@ def find_error_rows(result, last_len, n_num_input, highlight_dict, compare_mode) continue api_info = ApiInfo(api_input=api_in, api_output=api_out, num_pointer=num) - # if summary_compare: if compare_mode == Const.SUMMARY: for rule in HighlightRules.summary_compare_rules.values(): - # rule.apply(api_info, color_columns, summary_compare) rule.apply(api_info, color_columns, compare_mode) else: for rule in HighlightRules.compare_rules.values(): - # rule.apply(api_info, color_columns, summary_compare) rule.apply(api_info, color_columns, compare_mode) highlight_dict.get('red_rows', []).extend(list(set(red_lines))) @@ -616,10 +592,8 @@ def get_name_and_state(name): return api_name, state -# def find_compare_result_error_rows(result_df, highlight_dict, summary_compare, md5_compare): def find_compare_result_error_rows(result_df, highlight_dict, compare_mode): """将dataframe根据API分组,并找到有误差的算子用于高亮""" - # summary_md5_select = [summary_compare, md5_compare] result = result_df.values start, input_num, output_num, end = 0, 0, 0, len(result_df) last_api_name, last_state = None, None @@ -635,8 +609,6 @@ def find_compare_result_error_rows(result_df, highlight_dict, compare_mode): num, last_state = 1, state else: output_num = num - # find_error_rows(result[start:start + input_num + output_num], start, input_num, highlight_dict, - # summary_md5_select) find_error_rows(result[start:start + input_num + output_num], start, input_num, highlight_dict, compare_mode) num, last_api_name, last_state = 1, api_name, state @@ -649,8 +621,6 @@ def find_compare_result_error_rows(result_df, highlight_dict, compare_mode): input_num = num else: output_num = num - # find_error_rows(result[start:start + input_num + output_num], start, input_num, highlight_dict, - # summary_md5_select) find_error_rows(result[start:start + input_num + output_num], start, input_num, highlight_dict, compare_mode) @@ -738,15 +708,11 @@ def compare_core(input_parma, output_path, **kwargs): with FileOpen(input_parma.get("npu_json_path"), "r") as npu_json, \ FileOpen(input_parma.get("bench_json_path"), "r") as bench_json, \ FileOpen(input_parma.get("stack_json_path"), "r") as stack_json: - # result_df = compare_process([npu_json, bench_json, stack_json], stack_mode, fuzzy_match, - # summary_compare, md5_compare) result_df = compare_process([npu_json, bench_json, stack_json], stack_mode, fuzzy_match, compare_mode) - # if not md5_compare and not summary_compare: if compare_mode == Const.ALL: result_df = _do_multi_process(input_parma, result_df) - # find_compare_result_error_rows(result_df, highlight_dict, summary_compare, md5_compare) find_compare_result_error_rows(result_df, highlight_dict, compare_mode) highlight_rows_xlsx(result_df, highlight_dict, file_path) if auto_analyze: @@ -925,7 +891,6 @@ def read_op(op_data, op_name): return op_parsed_list -# def compare_process(file_handles, stack_mode, fuzzy_match, summary_compare=False, md5_compare=False): def compare_process(file_handles, stack_mode, fuzzy_match, compare_mode): npu_json_handle, bench_json_handle, stack_json_handle = file_handles npu_json_data = json.load(npu_json_handle) @@ -961,7 +926,6 @@ def compare_process(file_handles, stack_mode, fuzzy_match, compare_mode): else: npu_op_parsed_list.append({'full_op_name': op_name_npu, 'full_info': None}) - # npu_merge_list = merge_tensor(npu_op_parsed_list, summary_compare, md5_compare) npu_merge_list = merge_tensor(npu_op_parsed_list, compare_mode) if npu_merge_list: npu_ops_queue.append(npu_merge_list) @@ -979,7 +943,6 @@ def compare_process(file_handles, stack_mode, fuzzy_match, compare_mode): else: bench_op_parsed_list.append({'full_op_name': op_name_bench, 'full_info': None}) - # bench_merge_list = merge_tensor(bench_op_parsed_list, summary_compare, md5_compare) bench_merge_list = merge_tensor(bench_op_parsed_list, compare_mode) if bench_merge_list: bench_ops_queue.append(bench_merge_list) @@ -999,36 +962,22 @@ def compare_process(file_handles, stack_mode, fuzzy_match, compare_mode): b_match_data = bench_ops_queue[b_match_point] un_match_data = npu_ops_queue[0: n_match_point] for npu_data in un_match_data: - # get_un_match_accuracy(result, npu_data, md5_compare, summary_compare) get_un_match_accuracy(result, npu_data, compare_mode) - # get_accuracy(result, n_match_data, b_match_data, summary_compare, md5_compare) get_accuracy(result, n_match_data, b_match_data, compare_mode) del npu_ops_queue[0: n_match_point + 1] del bench_ops_queue[0: b_match_point + 1] if npu_ops_queue: for npu_data in npu_ops_queue: - # get_un_match_accuracy(result, npu_data, md5_compare, summary_compare) get_un_match_accuracy(result, npu_data, compare_mode) - # header = [] - # if md5_compare: - # header = CompareConst.MD5_COMPARE_RESULT_HEADER[:] - # elif summary_compare: - # header = CompareConst.SUMMARY_COMPARE_RESULT_HEADER[:] - # else: - # header = CompareConst.COMPARE_RESULT_HEADER[:] header = CompareConst.HEAD_OF_COMPARE_MODE.get(compare_mode) - - # all_mode_bool = not (summary_compare or md5_compare) if stack_mode: - # if all_mode_bool: if compare_mode == Const.ALL: header.append(CompareConst.STACK) header.append(CompareConst.DATA_NAME) else: header.append(CompareConst.STACK) else: - # if all_mode_bool: if compare_mode == Const.ALL: for row in result: del row[-2] @@ -1041,7 +990,6 @@ def compare_process(file_handles, stack_mode, fuzzy_match, compare_mode): return result_df -# def get_un_match_accuracy(result, n_dict, md5_compare, summary_compare): def get_un_match_accuracy(result, n_dict, compare_mode): index_out = 0 npu_stack_info = n_dict.get("stack_info", None) @@ -1056,14 +1004,12 @@ def get_un_match_accuracy(result, n_dict, compare_mode): index_out += 1 result_item = [n_name, bench_name, n_struct[0], bench_type, n_struct[1], bench_shape] - # if md5_compare: if compare_mode == Const.MD5: result_item.extend([CompareConst.NAN] * 3) if npu_stack_info and index == 0: result_item.extend(npu_stack_info) result.append(result_item) continue - # if summary_compare: if compare_mode == Const.SUMMARY: result_item.extend([CompareConst.NAN] * 8) else: @@ -1076,7 +1022,6 @@ def get_un_match_accuracy(result, n_dict, compare_mode): result_item.append(err_msg) if npu_stack_info and index == 0: result_item.extend(npu_stack_info) - # if not md5_compare and not summary_compare and result_item[1] == CompareConst.NAN: if compare_mode == Const.ALL and result_item[1] == CompareConst.NAN: if index == 0: result_item.extend(["-1"]) diff --git a/debug/accuracy_tools/msprobe/pytorch/compare/highlight.py b/debug/accuracy_tools/msprobe/pytorch/compare/highlight.py index 2dc15c1a45..82e86cf3a1 100644 --- a/debug/accuracy_tools/msprobe/pytorch/compare/highlight.py +++ b/debug/accuracy_tools/msprobe/pytorch/compare/highlight.py @@ -7,17 +7,14 @@ from msprobe.core.common.const import CompareConst, Const class HighlightCheck(abc.ABC): @abc.abstractmethod - # def apply(self, info, color_columns, summary_compare, md5_compare=False): def apply(self, info, color_columns, compare_mode): raise NotImplementedError class CheckOrderMagnitude(HighlightCheck): """检查Max diff的数量级差异""" - # def apply(self, info, color_columns, summary_compare, md5_compare=False): def apply(self, info, color_columns, compare_mode): api_in, api_out, num = info - # max_diff_index = get_header_index('Max diff' if summary_compare else 'MaxAbsErr', summary_compare) max_diff_index = get_header_index('Max diff' if compare_mode == Const.SUMMARY else 'MaxAbsErr', compare_mode) if abs(api_in[max_diff_index]) > abs(api_out[max_diff_index]): return @@ -29,10 +26,8 @@ class CheckOrderMagnitude(HighlightCheck): class CheckOneThousandErrorRatio(HighlightCheck): """检查千分误差比率""" - # def apply(self, info, color_columns, summary_compare, md5_compare=False): def apply(self, info, color_columns, compare_mode): api_in, api_out, num = info - # one_thousand_index = get_header_index('One Thousandth Err Ratio', summary_compare) one_thousand_index = get_header_index('One Thousandth Err Ratio', compare_mode) if not isinstance(api_in[one_thousand_index], (float, int)) or not isinstance(api_out[one_thousand_index], (float, int)): return @@ -44,10 +39,8 @@ class CheckOneThousandErrorRatio(HighlightCheck): class CheckCosineSimilarity(HighlightCheck): """检查余弦相似度""" - # def apply(self, info, color_columns, summary_compare, md5_compare=False): def apply(self, info, color_columns, compare_mode): api_in, api_out, num = info - # cosine_index = get_header_index('Cosine', summary_compare) cosine_index = get_header_index('Cosine', compare_mode) if not isinstance(api_in[cosine_index], (float, int)) or not isinstance(api_out[cosine_index], (float, int)): return @@ -57,11 +50,8 @@ class CheckCosineSimilarity(HighlightCheck): class CheckMaxRelativeDiff(HighlightCheck): """检查最大相对差异""" - # def apply(self, info, color_columns, summary_compare, md5_compare=False): def apply(self, info, color_columns, compare_mode): api_in, api_out, num = info - # max_diff_index = get_header_index('Max diff', summary_compare) - # bench_max_index = get_header_index('Bench max', summary_compare) max_diff_index = get_header_index('Max diff', compare_mode) bench_max_index = get_header_index('Bench max', compare_mode) input_max_relative_diff = np.abs(np.divide(api_in[max_diff_index], max(0.01, api_in[bench_max_index]))) @@ -77,11 +67,8 @@ class CheckMaxRelativeDiff(HighlightCheck): class CheckMd5Diff(HighlightCheck): """检查md5值差异""" - # def apply(self, info, color_columns, summary_compare, md5_compare=False): def apply(self, info, color_columns, compare_mode): line, num = info - # npu_md5_index = get_header_index('NPU MD5', summary_compare, md5_compare) - # bench_md5_index = get_header_index('BENCH MD5', summary_compare, md5_compare) npu_md5_index = get_header_index('NPU MD5', compare_mode) bench_md5_index = get_header_index('BENCH MD5', compare_mode) if str(line[npu_md5_index]) != str(line[bench_md5_index]): @@ -91,12 +78,8 @@ class CheckMd5Diff(HighlightCheck): class CheckOverflow(HighlightCheck): """检查是否存在溢出""" - # def apply(self, info, color_columns, summary_compare, md5_compare=False): def apply(self, info, color_columns, compare_mode): line, num = info - # npu_max_index = get_header_index('NPU max', summary_compare) - # npu_min_index = get_header_index('NPU min', summary_compare) - # max_diff_index = get_header_index('Max diff' if summary_compare else 'MaxAbsErr', summary_compare) npu_max_index = get_header_index('NPU max', compare_mode) npu_min_index = get_header_index('NPU min', compare_mode) max_diff_index = get_header_index('Max diff' if compare_mode == Const.SUMMARY else 'MaxAbsErr', compare_mode) -- Gitee From a6473b3762cd05b0efde9f8b70e6caff350f17ac Mon Sep 17 00:00:00 2001 From: Linwei-Ying Date: Tue, 23 Jul 2024 16:05:42 +0800 Subject: [PATCH 5/7] review fix add compare_mode --- debug/accuracy_tools/msprobe/pytorch/compare/acc_compare.py | 1 - 1 file changed, 1 deletion(-) diff --git a/debug/accuracy_tools/msprobe/pytorch/compare/acc_compare.py b/debug/accuracy_tools/msprobe/pytorch/compare/acc_compare.py index fc1ce1906f..f2c4cf28ff 100644 --- a/debug/accuracy_tools/msprobe/pytorch/compare/acc_compare.py +++ b/debug/accuracy_tools/msprobe/pytorch/compare/acc_compare.py @@ -184,7 +184,6 @@ def match_op(npu_queue, bench_queue, fuzzy_match): return -1, -1 -# def get_accuracy(result, n_dict, b_dict, summary_compare=False, md5_compare=False): def get_accuracy(result, n_dict, b_dict, compare_mode): def get_accuracy_core(n_start, n_len, b_start, b_len, key): min_len = min(n_len, b_len) -- Gitee From 02620aeb0045cafa2aab06be7768b928c01d7898 Mon Sep 17 00:00:00 2001 From: Linwei-Ying Date: Tue, 23 Jul 2024 16:20:20 +0800 Subject: [PATCH 6/7] review fix add compare_mode --- debug/accuracy_tools/msprobe/pytorch/compare/acc_compare.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/debug/accuracy_tools/msprobe/pytorch/compare/acc_compare.py b/debug/accuracy_tools/msprobe/pytorch/compare/acc_compare.py index f2c4cf28ff..ced1393c53 100644 --- a/debug/accuracy_tools/msprobe/pytorch/compare/acc_compare.py +++ b/debug/accuracy_tools/msprobe/pytorch/compare/acc_compare.py @@ -282,7 +282,7 @@ def get_accuracy(result, n_dict, b_dict, compare_mode): result_item.extend(npu_stack_info) else: result_item.append(CompareConst.NONE) - if all_mode_bool: + if compare_mode == compare_mode.ALL: result_item.append(npu_data_name[n_start + index]) result.append(result_item) -- Gitee From 06f98f6dbbe92a963ae28fc469e481a72bb3ec1d Mon Sep 17 00:00:00 2001 From: Linwei-Ying Date: Tue, 23 Jul 2024 16:22:19 +0800 Subject: [PATCH 7/7] review fix add compare_mode --- debug/accuracy_tools/msprobe/pytorch/compare/acc_compare.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/debug/accuracy_tools/msprobe/pytorch/compare/acc_compare.py b/debug/accuracy_tools/msprobe/pytorch/compare/acc_compare.py index ced1393c53..6bc4329740 100644 --- a/debug/accuracy_tools/msprobe/pytorch/compare/acc_compare.py +++ b/debug/accuracy_tools/msprobe/pytorch/compare/acc_compare.py @@ -282,7 +282,7 @@ def get_accuracy(result, n_dict, b_dict, compare_mode): result_item.extend(npu_stack_info) else: result_item.append(CompareConst.NONE) - if compare_mode == compare_mode.ALL: + if compare_mode == Const.ALL: result_item.append(npu_data_name[n_start + index]) result.append(result_item) -- Gitee