From 61c8801c2f3cbf5f7c3a474c14c8202a03385f36 Mon Sep 17 00:00:00 2001 From: Linwei-Ying Date: Thu, 28 Aug 2025 14:29:40 +0800 Subject: [PATCH] compare md5 diff check fit P2POp --- .../msprobe/core/compare/acc_compare.py | 113 +++++++-------- .../diff_analyze/first_diff_analyze.py | 18 ++- .../msprobe/core/compare/highlight.py | 4 +- .../msprobe/core/compare/utils.py | 73 ++++++---- .../test/core_ut/compare/test_acc_compare.py | 72 ++++++---- .../core_ut/compare/test_acc_compare_utils.py | 132 ++++++++++-------- .../compare/test_cmp_first_diff_analyze.py | 18 +-- .../core_ut/compare/test_cmp_highlight.py | 1 + 8 files changed, 238 insertions(+), 193 deletions(-) diff --git a/debug/accuracy_tools/msprobe/core/compare/acc_compare.py b/debug/accuracy_tools/msprobe/core/compare/acc_compare.py index 2846d413a9..21869ba870 100644 --- a/debug/accuracy_tools/msprobe/core/compare/acc_compare.py +++ b/debug/accuracy_tools/msprobe/core/compare/acc_compare.py @@ -31,8 +31,8 @@ from msprobe.core.common.utils import CompareException, add_time_with_xlsx, chec set_dump_path, get_dump_mode, check_compare_param, load_stack_json, get_file_type, add_time_with_json from msprobe.core.compare.check import check_dump_json_str, check_stack_json_str, cross_dtype_mapping, \ check_configuration_param -from msprobe.core.compare.utils import merge_tensor, print_compare_ends_info, read_op, \ - reorder_op_x_list, set_stack_json_path, check_api_info_len +from msprobe.core.compare.utils import merge_tensor, print_compare_ends_info, read_op, set_stack_json_path, \ + reorder_index from msprobe.core.compare.config import ModeConfig, MappingConfig, MappingDict from msprobe.core.compare.multiprocessing_compute import CompareRealData from msprobe.core.compare.highlight import HighLight @@ -110,13 +110,21 @@ class Comparator: npu_json = input_param.get("npu_json_path") bench_json = input_param.get("bench_json_path") stack_json = input_param.get("stack_json_path") - result_df = self.compare_statistics([npu_json, bench_json, stack_json], rank) + parse_data = ParseData(self.mode_config, rank) # load and parse json data + npu_df, bench_df = parse_data.parse([npu_json, bench_json, stack_json]) + result_df = self.compare_statistics(npu_df, bench_df) if not result_df.values.tolist(): logger.warning("Can`t match any op. No compare result file generated.") return if self.mode_config.first_diff_analyze: - first_diff_analyze = FirstDiffAnalyze(self.mode_config) + # add P2POp additional info from npu_df and bench_df to result_df + result_df['NPU P2POp op'] = npu_df['op'] + result_df['Bench P2POp op'] = bench_df['op'] + result_df['NPU P2POp peer'] = npu_df['peer'] + result_df['Bench P2POp peer'] = bench_df['peer'] + + first_diff_analyze = FirstDiffAnalyze(self.mode_config, rank) check_result = first_diff_analyze.check(result_df) save_json(file_path, check_result, indent=4) logger.info(f"Saving json file to disk: {file_path}") @@ -149,11 +157,7 @@ class Comparator: print_compare_ends_info() - def compare_statistics(self, file_list, rank): - # load and parse json data - parse_data = ParseData(self.mode_config, rank) - npu_df, bench_df = parse_data.parse(file_list) - + def compare_statistics(self, npu_df, bench_df): npu_df[[Const.DTYPE, Const.SHAPE]] = npu_df[[Const.DTYPE, Const.SHAPE]].astype(str) bench_df[[Const.DTYPE, Const.SHAPE]] = bench_df[[Const.DTYPE, Const.SHAPE]].astype(str) @@ -227,57 +231,42 @@ class ParseData: # 从json中循环解析API数据,遍历所有API for data_name in apis_data: check_op_str_pattern_valid(data_name) - merge_list = self.gen_merge_list(data_json, data_name, stack_json_data) - if not merge_list: + op_parsed_list = self.gen_merge_list(data_json, data_name, stack_json_data) + if not op_parsed_list: continue + reordered_index_list = reorder_index(op_parsed_list) + for i, index in enumerate(reordered_index_list): + op_item = op_parsed_list[index] + + # common key + result[CompareConst.OP_NAME].append(op_item.get('full_op_name')) + result[Const.DTYPE].append(op_item.get(Const.DTYPE)) + result[Const.SHAPE].append(op_item.get(Const.SHAPE)) + result[Const.STATE].append(op_item.get(Const.STATE)) + result[Const.REQ_GRAD].append(op_item.get(Const.REQ_GRAD)) + result[Const.API_ORIGIN_NAME].append(data_name) + summary_data = [ + str(op_item.get(key)) if op_item.get(key) is None else op_item.get(key) + for key in Const.SUMMARY_METRICS_LIST + ] + result[Const.SUMMARY].append(summary_data) - op_name_list = merge_list.get(CompareConst.OP_NAME) - summary_list = merge_list.get(Const.SUMMARY) - data_name_list = merge_list.get(Const.DATA_NAME) - state_list = merge_list.get(Const.STATE) - requires_grad_list = merge_list.get(Const.REQ_GRAD) - op_name_reorder, summary_reorder, data_name_reorder, state_reorder, requires_grad_reorder = ( - reorder_op_x_list(op_name_list, summary_list, data_name_list, state_list, requires_grad_list)) - # 遍历单个API的所有item - for index, (op_name, state) in enumerate(zip(op_name_reorder, state_reorder)): - result[CompareConst.OP_NAME].append(op_name) - if state == Const.INPUT or state == Const.KWARGS: - info_list = merge_list[CompareConst.INPUT_STRUCT] - elif state == Const.OUTPUT: - info_list = merge_list[CompareConst.OUTPUT_STRUCT] - elif state == Const.PARAMS: - info_list = merge_list[CompareConst.PARAMS_STRUCT] - elif state == Const.PARAMS_GRAD: - info_list = merge_list[CompareConst.PARAMS_GRAD_STRUCT] - else: - info_list = merge_list[CompareConst.DEBUG_STRUCT] - check_api_info_len(op_name, info_list, 1) - struct = info_list.pop(0) - - check_api_info_len(op_name, struct, 2) - result[Const.DTYPE].append(struct[0]) - result[Const.SHAPE].append(struct[1]) - - check_api_info_len(op_name, summary_reorder, 1) - result[Const.SUMMARY].append(summary_reorder.pop(0)) + # dump_mode differ key + if self.mode_config.dump_mode == Const.MD5: + result[Const.MD5].append(op_parsed_list[index].get(Const.MD5)) + if self.mode_config.dump_mode == Const.ALL: + result[Const.DATA_NAME].append(op_item.get(Const.DATA_NAME)) - if index == 0 and self.mode_config.stack_mode: - check_api_info_len(op_name, merge_list[Const.STACK_INFO], 1) - result[Const.STACK_INFO].append(merge_list[Const.STACK_INFO][0]) + # mode_config stack_mode addition key + if i == 0 and self.mode_config.stack_mode: + result[Const.STACK_INFO].append(op_parsed_list[-1].get('full_info')) else: result[Const.STACK_INFO].append(None) - if self.mode_config.dump_mode == Const.MD5: - check_api_info_len(op_name, struct, 3) - result[Const.MD5].append(struct[2]) - if self.mode_config.dump_mode == Const.ALL: - check_api_info_len(op_name, data_name_reorder, 1) - result[Const.DATA_NAME].append(data_name_reorder.pop(0)) - - result[Const.STATE].append(state) - result[Const.API_ORIGIN_NAME].append(data_name) - check_api_info_len(op_name, requires_grad_reorder, 1) - result[Const.REQ_GRAD].append(requires_grad_reorder.pop(0)) + # mode_config first_diff_analyze addition key + if self.mode_config.first_diff_analyze: + result.setdefault('op', []).append(op_item.get('op', str(None))) + result.setdefault('peer', []).append(op_item.get('peer', str(None))) progress_bar.update(1) progress_bar.close() @@ -293,14 +282,14 @@ class ParseData: stack_info = stack_json_data.get(op_name) if stack_info is not None: check_stack_json_str(stack_info, op_name) - # append only when stack_mode is True, - op_parsed_list.append({ - 'full_op_name': op_name, - 'full_info': stack_info - }) - - merge_list = merge_tensor(op_parsed_list, self.mode_config.dump_mode) - return merge_list + else: + stack_info = None + # always add stack_info whether stack_mode is True + op_parsed_list.append({ + 'full_op_name': op_name, + 'full_info': stack_info + }) + return op_parsed_list class ProcessDf: diff --git a/debug/accuracy_tools/msprobe/core/compare/diff_analyze/first_diff_analyze.py b/debug/accuracy_tools/msprobe/core/compare/diff_analyze/first_diff_analyze.py index f1192d8951..924312a2bc 100644 --- a/debug/accuracy_tools/msprobe/core/compare/diff_analyze/first_diff_analyze.py +++ b/debug/accuracy_tools/msprobe/core/compare/diff_analyze/first_diff_analyze.py @@ -15,6 +15,8 @@ import os +from tqdm import tqdm + from msprobe.core.common.const import Const, CompareConst from msprobe.core.common.utils import logger, CompareException from msprobe.core.common.file_utils import load_yaml @@ -29,8 +31,9 @@ cmp_metrics = thresholds.get('compare_metrics') class FirstDiffAnalyze: - def __init__(self, mode_config: ModeConfig): + def __init__(self, mode_config: ModeConfig, rank): self.mode_config = mode_config + self.rank = rank @staticmethod def single_metric_diff_check(cmp_metric, metric_value): @@ -105,11 +108,16 @@ class FirstDiffAnalyze: result = result_df.values header = result_df.columns.tolist() - api_batches = gen_api_batches(result) + api_batches = gen_api_batches(result, header) check_result = {} - for api_batch in api_batches: - result_slice = result[api_batch.start: api_batch.params_grad_end_index] - check_result[api_batch.api_name] = self.single_api_check(result_slice, header) + + default_bar_desc = 'API/Module diff check Progress' + bar_desc_add_rank = f'[{self.rank}]' + default_bar_desc if self.rank else default_bar_desc + with tqdm(total=len(api_batches), desc=bar_desc_add_rank, unit="api/module", ncols=100) as progress_bar: + for api_batch in api_batches: + result_slice = result[api_batch.start: api_batch.params_grad_end_index] + check_result[api_batch.api_name] = self.single_api_check(result_slice, header) + progress_bar.update(1) return check_result diff --git a/debug/accuracy_tools/msprobe/core/compare/highlight.py b/debug/accuracy_tools/msprobe/core/compare/highlight.py index 64c599b009..6e5aaa232d 100644 --- a/debug/accuracy_tools/msprobe/core/compare/highlight.py +++ b/debug/accuracy_tools/msprobe/core/compare/highlight.py @@ -18,7 +18,6 @@ import math import multiprocessing from collections import namedtuple -import numpy as np import openpyxl from openpyxl.styles import PatternFill from openpyxl.utils.dataframe import dataframe_to_rows @@ -241,7 +240,8 @@ class HighLight: def find_compare_result_error_rows(self, result_df, highlight_dict): """将dataframe根据API分组,并找到有误差的算子用于高亮""" result = result_df.values - api_batches = gen_api_batches(result) + header = result_df.columns.tolist() + api_batches = gen_api_batches(result, header) default_bar_desc = 'API/Module Analyse Progress' bar_desc_add_rank = f'[{self.rank}]' + default_bar_desc if self.rank else default_bar_desc with tqdm(total=len(api_batches), desc=bar_desc_add_rank, unit="item", ncols=100) as progress_bar: diff --git a/debug/accuracy_tools/msprobe/core/compare/utils.py b/debug/accuracy_tools/msprobe/core/compare/utils.py index 2e7cdd5ae7..5503bdd3fe 100644 --- a/debug/accuracy_tools/msprobe/core/compare/utils.py +++ b/debug/accuracy_tools/msprobe/core/compare/utils.py @@ -147,6 +147,17 @@ def op_item_parse(op_data, op_name: str, state: str, depth: int = 0) -> list: else: item_list.extend(op_item_parse(data, op_name, state, depth + 1)) elif isinstance(op_data, dict): + if is_p2pop_leaf_data(op_data): + p2pop_item = {} + for key in ['class_type', 'op', 'peer', 'tag', 'group_id']: + p2pop_item[key] = op_data.get(key) + op_data = op_data.get('tensor') + if isinstance(op_data, dict): + op_item = gen_op_item(op_data, op_name, state) + else: + op_item = default_item + op_item.update(p2pop_item) + return [op_item] if is_leaf_data(op_data): return [gen_op_item(op_data, op_name, state)] for sub_name, sub_data in op_data.items(): @@ -154,6 +165,10 @@ def op_item_parse(op_data, op_name: str, state: str, depth: int = 0) -> list: return item_list +def is_p2pop_leaf_data(op_data): + return op_data.get('class_type') == 'torch.distributed.P2POp' + + def is_leaf_data(op_data): return 'type' in op_data and isinstance(op_data['type'], str) @@ -267,12 +282,6 @@ def merge_tensor(tensor_list, dump_mode): return op_dict if op_dict[CompareConst.OP_NAME] else {} -def check_api_info_len(op_name, info_list, len_require): - if len(info_list) < len_require: - logger.error(f'Index out of bounds error, please check info of api: {op_name}.') - raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) - - def print_compare_ends_info(): total_len = len(CompareConst.COMPARE_ENDS_SUCCESSFULLY) + Const.FILL_CHAR_NUMS logger.info('*' * total_len) @@ -347,6 +356,29 @@ def api_batches_update(api_batches, api_name, state, index): api_batches.append(ApiBatch(api_name, index)) +def reorder_index(op_parsed_list): + """ + 对单个api解析的op_items的index进行重排,将parameter的index放到output前面,返回新的重排后的index列表,op_parsed_list不变 + """ + index_param = [] + index_output = [] + index_param_grad = [] + index_other = [] + for i, op_item in enumerate(op_parsed_list[:-1]): + state = op_item.get(Const.STATE) + if state == Const.PARAMS: + index_param.append(i) + elif state == Const.OUTPUT: + index_output.append(i) + elif state == Const.PARAMS_GRAD: + index_param_grad.append(i) + else: + index_other.append(i) + # 合并others, parameters, 和output,确保parameters排在output前面 + reordered_index_list = index_other + index_param + index_output + index_param_grad + return reordered_index_list + + def reorder_op_name_list(op_name_list, state_list): if not op_name_list: return op_name_list, state_list @@ -378,27 +410,6 @@ def reorder_op_name_list(op_name_list, state_list): return op_name_reorder, state_reorder -def reorder_op_x_list(op_name_list, summary_list, data_name_list, state_list, requires_grad_list): - """ - 对op_name, summary, data_name, state, requires_grad重新排序, - 把parameters放到input后output前,data_name由于统计量比对时,为None,单独处理 - """ - if not op_name_list or not summary_list: - return op_name_list, summary_list, data_name_list, state_list, requires_grad_list - - index_map = {name: index for index, name in enumerate(op_name_list)} - - op_name_reorder, state_reorder = reorder_op_name_list(op_name_list, state_list) - summary_reorder = [summary_list[index_map.get(name)] for name in op_name_reorder] - requires_grad_reorder = [requires_grad_list[index_map.get(name)] for name in op_name_reorder] - if data_name_list: - data_name_reorder = [data_name_list[index_map.get(name)] for name in op_name_reorder] - else: - data_name_reorder = data_name_list - - return op_name_reorder, summary_reorder, data_name_reorder, state_reorder, requires_grad_reorder - - def process_summary_data(summary_data): """处理summary_data中的nan值,返回处理后的列表""" return [CompareConst.NAN if isinstance(x, float) and math.isnan(x) else x for x in summary_data] @@ -621,11 +632,13 @@ def make_result_table(result, dump_mode, stack_mode): return result_df -def gen_api_batches(result: np.ndarray): +def gen_api_batches(result: np.ndarray, header: list): + api_name_index = header.index(Const.API_ORIGIN_NAME) + state_name_index = header.index(Const.STATE) api_batches = [] for i, res_i in enumerate(result): - api_name = safe_get_value(res_i, -1, "res_i") # 内部定义倒数第一个元素必是api_origin_name - state = safe_get_value(res_i, -2, "res_i") # 内部定义倒数第二个元素必是state + api_name = safe_get_value(res_i, api_name_index, "res_i") + state = safe_get_value(res_i, state_name_index, "res_i") api_batches_update(api_batches, api_name, state, i) return api_batches diff --git a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare.py b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare.py index c00d5a061d..3cbef9af6c 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare.py @@ -342,18 +342,26 @@ class TestUtilsMethods(unittest.TestCase): json_data = {'data': {'Functional.linear.0.forward': op_data}} op_name = 'Functional.linear.0.forward' stack_json_data = {'Functional.linear.0.forward': ['File']} - merge_list = { - 'debug_struct': [], - 'input_struct': [('torch.float32', [2, 2])], - 'op_name': ['Functional.linear.0.forward.input.0'], - 'output_struct': [], - 'params_struct': [], - 'params_grad_struct': [], - 'stack_info': [['File']], - 'summary': [[1, 1, 1, 1]], - 'state': ['input'], - 'requires_grad': ['False'] - } + target_merge_list = [ + { + 'full_op_name': 'Functional.linear.0.forward.input.0', + 'type': 'torch.Tensor', + 'dtype': 'torch.float32', + 'shape': [2, 2], + 'requires_grad': 'False', + 'Max': 1, + 'Min': 1, + 'Mean': 1, + 'Norm': 1, + 'md5': '00000000', + 'data_name': 'Functional.linear.0.forward.input.0.pt', + 'state': 'input' + }, + { + 'full_op_name': 'Functional.linear.0.forward', + 'full_info': ['File'] + } + ] config_dict = { 'stack_mode': True, @@ -364,7 +372,7 @@ class TestUtilsMethods(unittest.TestCase): mode_config = ModeConfig(**config_dict) result = ParseData(mode_config, 'rank0').gen_merge_list(json_data, op_name, stack_json_data) - self.assertEqual(result, merge_list) + self.assertEqual(result, target_merge_list) def test_check_op_item_fuzzy(self): config_dict = { @@ -397,7 +405,9 @@ class TestUtilsMethods(unittest.TestCase): from msprobe.pytorch.compare.pt_compare import read_real_data comparator = Comparator(read_real_data, mode_config, mapping_config) - result = comparator.compare_statistics(file_list, 'rank0') + parse_data = ParseData(mode_config, '') + npu_df, bench_df = parse_data.parse(file_list) + result = comparator.compare_statistics(npu_df, bench_df) o_data = [ ['Functional.linear.0.forward.input.0', 'Functional.linear.0.forward.input.0', 'torch.float32', 'torch.float32', '[2, 2]', '[2, 2]', 'False', 'False', @@ -506,19 +516,27 @@ class TestParseData(unittest.TestCase): parse_data = ParseData(mode_config, 'rank0') merge_list = parse_data.gen_merge_list(npu_json_data, 'Functional.linear.0.forward', stack_json_data) - target_dict = { - 'debug_struct': [], - 'input_struct': [('torch.float32', [2, 2])], - 'op_name': ['Functional.linear.0.forward.input.0'], - 'output_struct': [], - 'params_grad_struct': [], - 'params_struct': [], - 'stack_info': [['File']], - 'summary': [[2, 0, 1, 1]], - 'state': ['input'], - 'requires_grad': ['False'] - } - self.assertEqual(merge_list, target_dict) + target_merge_list = [ + { + 'full_op_name': 'Functional.linear.0.forward.input.0', + 'type': 'torch.Tensor', + 'dtype': 'torch.float32', + 'shape': [2, 2], + 'requires_grad': 'False', + 'Max': 2, + 'Min': 0, + 'Mean': 1, + 'Norm': 1, + 'md5': '00000000', + 'data_name': 'Functional.linear.0.forward.input.0.pt', + 'state': 'input' + }, + { + 'full_op_name': 'Functional.linear.0.forward', + 'full_info': ['File'] + } + ] + self.assertEqual(merge_list, target_merge_list) class TestProcessDf(unittest.TestCase): diff --git a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare_utils.py b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare_utils.py index e5ed872592..286fecf84d 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare_utils.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare_utils.py @@ -9,13 +9,14 @@ import zlib import tempfile import numpy as np +import pandas as pd from msprobe.core.common.const import CompareConst, Const from msprobe.core.common.utils import CompareException from msprobe.core.compare.utils import ApiItemInfo, _compare_parser, check_and_return_dir_contents, extract_json, \ count_struct, get_accuracy, get_rela_diff_summary_mode, merge_tensor, op_item_parse, read_op, result_item_init, \ - stack_column_process, table_value_is_valid, reorder_op_name_list, reorder_op_x_list, gen_op_item, ApiBatch, \ - get_paired_dirs + stack_column_process, table_value_is_valid, reorder_op_name_list, gen_op_item, ApiBatch, get_paired_dirs, \ + reorder_index, gen_api_batches # test_read_op_1 op_data = { @@ -589,6 +590,40 @@ class TestUtilsMethods(unittest.TestCase): self.assertFalse(result) +class TestReorderIndex(unittest.TestCase): + def test_reorder_index_mixed_states(self): + op_parsed_list = [ + {Const.STATE: "OTHER"}, + {Const.STATE: Const.OUTPUT}, + {Const.STATE: Const.PARAMS}, + {Const.STATE: Const.PARAMS_GRAD}, + {Const.STATE: Const.INPUT}, + {"not_state": 123}, # 没有 STATE,算作 other + ] + + reordered = reorder_index(op_parsed_list) + self.assertTrue(reordered == [0, 4, 2, 1, 3]) + + def test_reorder_index_all_params(self): + op_parsed_list = [ + {Const.STATE: Const.PARAMS}, + {Const.STATE: Const.PARAMS}, + {Const.STATE: Const.PARAMS}, + ] + reordered = reorder_index(op_parsed_list) + self.assertTrue(reordered == [0, 1]) + + def test_reorder_index_empty(self): + op_parsed_list = [] + reordered = reorder_index(op_parsed_list) + self.assertTrue(reordered == []) + + def test_reorder_index_single_element(self): + op_parsed_list = [{Const.STATE: Const.PARAMS}] + reordered = reorder_index(op_parsed_list) + self.assertTrue(reordered == []) + + class TestReorderOpNameList(unittest.TestCase): def test_reorder_op_name_list(self): # 标准顺序 @@ -621,62 +656,6 @@ class TestReorderOpNameList(unittest.TestCase): self.assertEqual(state_reorder, expected_state) -class TestReorderOpXList(unittest.TestCase): - def test_reorder_op_x_list(self): - # 标准顺序 - op_name_list = ["op.forward.input.0", "op.forward.output.0", "op.forward.parameters.weight"] - summary_list = ["summary1", "summary2", "summary3"] - data_name_list = ["data1", "data2", "data3"] - state_list = ["input", "output", "parameters"] - requires_grad_list = [True, None, False] - result_op_name, result_summary, result_data_name, result_state, result_requires_grad = reorder_op_x_list( - op_name_list, summary_list, data_name_list, state_list, requires_grad_list) - self.assertEqual(result_op_name, ["op.forward.input.0", "op.forward.parameters.weight", "op.forward.output.0"]) - self.assertEqual(result_summary, ["summary1", "summary3", "summary2"]) - self.assertEqual(result_data_name, ["data1", "data3", "data2"]) - self.assertEqual(result_state, ["input", "parameters", "output"]) - self.assertEqual(result_requires_grad, [True, False, None]) - - # 空 op_name_list 或 summary_list - op_name_list = [] - summary_list = [] - data_name_list = ["data1", "data2", "data3"] - state_list = [] - result_op_name, result_summary, result_data_name, result_state, result_requires_grad = reorder_op_x_list( - op_name_list, summary_list, data_name_list, state_list, requires_grad_list) - self.assertEqual(result_op_name, []) - self.assertEqual(result_summary, []) - self.assertEqual(result_data_name, ["data1", "data2", "data3"]) - self.assertEqual(result_state, []) - self.assertEqual(result_requires_grad, [True, None, False]) - - # 空 data_name_list - op_name_list = ["op.forward.input.0", "op.forward.output.0", "op.forward.parameters.weight"] - summary_list = ["summary1", "summary2", "summary3"] - data_name_list = [] - state_list = ["input", "output", "parameters"] - result_op_name, result_summary, result_data_name, result_state, result_requires_grad = reorder_op_x_list( - op_name_list, summary_list, data_name_list, state_list, requires_grad_list) - self.assertEqual(result_op_name, ["op.forward.input.0", "op.forward.parameters.weight", "op.forward.output.0"]) - self.assertEqual(result_summary, ["summary1", "summary3", "summary2"]) - self.assertEqual(result_data_name, []) - self.assertEqual(result_state, ["input", "parameters", "output"]) - self.assertEqual(result_requires_grad, [True, False, None]) - - # data_name_list 为 None - op_name_list = ["op.forward.input.0", "op.forward.output.0", "op.forward.parameters.weight"] - summary_list = ["summary1", "summary2", "summary3"] - data_name_list = None - state_list = ["input", "output", "parameters"] - result_op_name, result_summary, result_data_name, result_state, result_requires_grad = reorder_op_x_list( - op_name_list, summary_list, data_name_list, state_list, requires_grad_list) - self.assertEqual(result_op_name, ["op.forward.input.0", "op.forward.parameters.weight", "op.forward.output.0"]) - self.assertEqual(result_summary, ["summary1", "summary3", "summary2"]) - self.assertEqual(result_data_name, None) - self.assertEqual(result_state, ["input", "parameters", "output"]) - self.assertEqual(result_requires_grad, [True, False, None]) - - class TestGenOpItem(unittest.TestCase): def test_gen_op_item_with_data_name(self): op_data = { @@ -914,6 +893,43 @@ class TestApiBatch(unittest.TestCase): self.assertEqual(api_batch.params_grad_end_index, 5) +class TestGenApiBatches(unittest.TestCase): + def test_gen_api_batches_normal(self): + result_df_part1 = pd.DataFrame(o_result) + result_df_part1.columns = CompareConst.SUMMARY_COMPARE_RESULT_HEADER_STACK + new_columns = [ + ['input', 'Functional.conv2d.0.forward'], + ['input', 'Functional.conv2d.0.forward'], + ['input', 'Functional.conv2d.0.forward'], + ['parameters', 'Functional.conv2d.0.forward'], + ['parameters', 'Functional.conv2d.0.forward'], + ['output', 'Functional.conv2d.0.forward'], + ['parameters_grad', 'Functional.conv2d.0.forward'], + ['parameters_grad', 'Functional.conv2d.0.forward'] + ] + result_df_part2 = pd.DataFrame(new_columns) + result_df_part2.columns = [Const.STATE, Const.API_ORIGIN_NAME] + result_df = pd.concat([result_df_part1, result_df_part2], axis=1) + result = result_df.values + header = result_df.columns.tolist() + result_api_batches = gen_api_batches(result, header) + + api_batch = ApiBatch('Functional.conv2d.0.forward', 0) + api_batch.input_len = 3 + api_batch.output_end_index = 6 + api_batch.params_end_index = 5 + api_batch.params_grad_end_index = 8 + api_batch._state = 'parameters_grad' + + result_api_batch = result_api_batches[0] + self.assertEqual(result_api_batch.api_name, api_batch.api_name) + self.assertEqual(result_api_batch.start, api_batch.start) + self.assertEqual(result_api_batch.input_len, api_batch.input_len) + self.assertEqual(result_api_batch.params_end_index, api_batch.params_end_index) + self.assertEqual(result_api_batch.params_grad_end_index, api_batch.params_grad_end_index) + self.assertEqual(result_api_batch._state, api_batch._state) + + class TestGetPairedSteps(unittest.TestCase): def setUp(self): self.npu_dir = tempfile.TemporaryDirectory() diff --git a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_cmp_first_diff_analyze.py b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_cmp_first_diff_analyze.py index ce7094de70..cb2fde7eb6 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_cmp_first_diff_analyze.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_cmp_first_diff_analyze.py @@ -24,7 +24,7 @@ class TestFirstDiffAnalyze(unittest.TestCase): {'compare_metrics': ['MaxRelativeErr', 'NormRelativeErr'], 'MaxRelativeErr': [0.5]}) def test_single_metric_diff_check_true(self): mode_config = ModeConfig(first_diff_analyze=True) - first_diff_analyze = FirstDiffAnalyze(mode_config) + first_diff_analyze = FirstDiffAnalyze(mode_config, '') result = first_diff_analyze.single_metric_diff_check('MaxRelativeErr', '60.0%') self.assertTrue(result) @@ -32,7 +32,7 @@ class TestFirstDiffAnalyze(unittest.TestCase): {'compare_metrics': ['MaxRelativeErr', 'NormRelativeErr'], 'MaxRelativeErr': [0.5]}) def test_single_metric_diff_check_false(self): mode_config = ModeConfig(first_diff_analyze=True) - first_diff_analyze = FirstDiffAnalyze(mode_config) + first_diff_analyze = FirstDiffAnalyze(mode_config, '') result = first_diff_analyze.single_metric_diff_check('MaxRelativeErr', '30.0%') self.assertFalse(result) @@ -40,7 +40,7 @@ class TestFirstDiffAnalyze(unittest.TestCase): {'compare_metrics': ['MaxRelativeErr', 'NormRelativeErr'], 'NormRelativeErr': [0.5]}) def test_single_metric_diff_check_miss_threshold(self): mode_config = ModeConfig(first_diff_analyze=True) - first_diff_analyze = FirstDiffAnalyze(mode_config) + first_diff_analyze = FirstDiffAnalyze(mode_config, '') with self.assertRaises(CompareException) as context: result = first_diff_analyze.single_metric_diff_check('MaxRelativeErr', '30.0%') self.assertEqual(context.exception.code, CompareException.MISSING_THRESHOLD_ERROR) @@ -49,7 +49,7 @@ class TestFirstDiffAnalyze(unittest.TestCase): {'compare_metrics': ['MaxRelativeErr', 'NormRelativeErr'], 'MaxRelativeErr': [0.5, 1.0]}) def test_single_metric_diff_check_wrong_threshold(self): mode_config = ModeConfig(first_diff_analyze=True) - first_diff_analyze = FirstDiffAnalyze(mode_config) + first_diff_analyze = FirstDiffAnalyze(mode_config, '') with self.assertRaises(CompareException) as context: result = first_diff_analyze.single_metric_diff_check('MaxRelativeErr', '30.0%') self.assertEqual(context.exception.code, CompareException.WRONG_THRESHOLD_ERROR) @@ -73,7 +73,7 @@ class TestFirstDiffAnalyze(unittest.TestCase): ] } mode_config = ModeConfig(first_diff_analyze=True) - first_diff_analyze = FirstDiffAnalyze(mode_config) + first_diff_analyze = FirstDiffAnalyze(mode_config, 'rank0') result = first_diff_analyze.single_api_check(result_slice, self.header) self.assertEqual(result, expected_result) @@ -96,7 +96,7 @@ class TestFirstDiffAnalyze(unittest.TestCase): ] } mode_config = ModeConfig(first_diff_analyze=True) - first_diff_analyze = FirstDiffAnalyze(mode_config) + first_diff_analyze = FirstDiffAnalyze(mode_config, 'rank0') result = first_diff_analyze.single_api_check(result_slice, self.header) self.assertEqual(result, expected_result) @@ -123,7 +123,7 @@ class TestFirstDiffAnalyze(unittest.TestCase): ] } mode_config = ModeConfig(dump_mode=Const.MD5, first_diff_analyze=True) - first_diff_analyze = FirstDiffAnalyze(mode_config) + first_diff_analyze = FirstDiffAnalyze(mode_config, 'rank0') result = first_diff_analyze.single_api_check(result_slice, md5_header) self.assertEqual(result, expected_result) @@ -150,7 +150,7 @@ class TestFirstDiffAnalyze(unittest.TestCase): ] } mode_config = ModeConfig(dump_mode=Const.MD5, first_diff_analyze=True) - first_diff_analyze = FirstDiffAnalyze(mode_config) + first_diff_analyze = FirstDiffAnalyze(mode_config, 'rank0') result = first_diff_analyze.single_api_check(result_slice, md5_header) self.assertEqual(result, expected_result) @@ -171,6 +171,6 @@ class TestFirstDiffAnalyze(unittest.TestCase): } } mode_config = ModeConfig(first_diff_analyze=True) - first_diff_analyze = FirstDiffAnalyze(mode_config) + first_diff_analyze = FirstDiffAnalyze(mode_config, 'rank1') result = first_diff_analyze.check(self.result_df) self.assertEqual(result, expected_result) diff --git a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_cmp_highlight.py b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_cmp_highlight.py index a3809974d1..c5190d3a57 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_cmp_highlight.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_cmp_highlight.py @@ -470,6 +470,7 @@ class TestUtilsMethods(unittest.TestCase): def test_find_compare_result_error_rows(self): result = [line_input, line_1, line_2, line_3] result_df = pd.DataFrame(result) + result_df.columns = CompareConst.COMPARE_RESULT_HEADER + [Const.STATE, Const.API_ORIGIN_NAME] highlight_dict_test = {"red_rows": set(), "yellow_rows": set(), "red_lines": [], "yellow_lines": []} mode_config = ModeConfig(dump_mode=Const.ALL) highlight = HighLight(mode_config, '') -- Gitee