diff --git a/debug/accuracy_tools/msprobe/core/common/utils.py b/debug/accuracy_tools/msprobe/core/common/utils.py index 211339f6884497ff01160c3c61bb61bd10c49aa6..3f57d095a666b5122ac9de264e8ade941bb36c5e 100644 --- a/debug/accuracy_tools/msprobe/core/common/utils.py +++ b/debug/accuracy_tools/msprobe/core/common/utils.py @@ -82,6 +82,8 @@ class MsprobeBaseException(Exception): INVALID_STATE_ERROR = 35 INVALID_API_NAME_ERROR = 36 CROSS_FRAME_ERROR = 37 + MISSING_THRESHOLD_ERROR = 38 + WRONG_THRESHOLD_ERROR = 38 def __init__(self, code, error_info: str = ""): super(MsprobeBaseException, self).__init__() @@ -284,6 +286,10 @@ def add_time_with_xlsx(name): return '{}_{}.xlsx'.format(name, time.strftime("%Y%m%d%H%M%S", time.localtime(time.time()))) +def add_time_with_json(name): + return '{}_{}.json'.format(name, time.strftime("%Y%m%d%H%M%S", time.localtime(time.time()))) + + def add_time_with_yaml(name): return '{}_{}.yaml'.format(name, time.strftime("%Y%m%d%H%M%S", time.localtime(time.time()))) diff --git a/debug/accuracy_tools/msprobe/core/compare/acc_compare.py b/debug/accuracy_tools/msprobe/core/compare/acc_compare.py index 1100cbe2136847e4db36bccafd065142714f96f3..af9a518ff4f0d38fb336defcf84e3d023bfb2daa 100644 --- a/debug/accuracy_tools/msprobe/core/compare/acc_compare.py +++ b/debug/accuracy_tools/msprobe/core/compare/acc_compare.py @@ -25,16 +25,18 @@ from tqdm import tqdm from msprobe.core.advisor.advisor import Advisor from msprobe.core.common.const import CompareConst, Const from msprobe.core.common.exceptions import FileCheckException -from msprobe.core.common.file_utils import load_json, remove_path, create_directory, save_excel +from msprobe.core.common.file_utils import load_json, remove_path, create_directory, save_excel, save_json from msprobe.core.common.log import logger from msprobe.core.common.utils import CompareException, add_time_with_xlsx, check_op_str_pattern_valid, \ - set_dump_path, get_dump_mode, check_compare_param, check_configuration_param, load_stack_json, get_file_type + set_dump_path, get_dump_mode, check_compare_param, check_configuration_param, load_stack_json, get_file_type, \ + add_time_with_json from msprobe.core.compare.check import check_dump_json_str, check_stack_json_str, cross_dtype_mapping from msprobe.core.compare.utils import merge_tensor, print_compare_ends_info, read_op, \ reorder_op_x_list, set_stack_json_path, check_api_info_len from msprobe.core.compare.config import ModeConfig, MappingConfig, MappingDict from msprobe.core.compare.multiprocessing_compute import CompareRealData from msprobe.core.compare.highlight import HighLight +from msprobe.core.compare.diff_analyze.first_diff_analyze import FirstDiffAnalyze @dataclass @@ -50,6 +52,7 @@ class ComparisonConfig: api_mapping: dict layer_mapping: dict compared_file_type: str + first_diff_analyze: bool class Comparator: @@ -58,17 +61,18 @@ class Comparator: self.mode_config = mode_config self.mapping_config = mapping_config self.cross_frame = is_cross_framework - self.mapping_dict = MappingDict(mapping_config) - @staticmethod - def process_output_file(output_path, suffix, compared_file_type): + def process_output_file(self, output_path, suffix, compared_file_type): file_name_prefix_mapping = { Const.DUMP_JSON_FILE: "compare_result", Const.DEBUG_JSON_FILE: "debug_compare_result" } file_name_prefix = file_name_prefix_mapping.get(compared_file_type, "compare_result") - file_name = add_time_with_xlsx(file_name_prefix + suffix) + if self.mode_config.first_diff_analyze: + file_name = add_time_with_json("compare_result" + suffix) + else: + file_name = add_time_with_xlsx(file_name_prefix + suffix) file_path = os.path.join(os.path.realpath(output_path), file_name) if os.path.exists(file_path): logger.warning(f"{file_path} will be deleted.") @@ -109,6 +113,13 @@ class Comparator: logger.warning("Can`t match any op. No compare result file generated.") return + if self.mode_config.first_diff_analyze: + first_diff_analyze = FirstDiffAnalyze(self.mode_config) + check_result = first_diff_analyze.check(result_df) + save_json(file_path, check_result, indent=4) + logger.info(f"Saving json file to disk: {file_path}") + return + # compare real data if self.mode_config.dump_mode == Const.ALL: compare_real_data = CompareRealData(self.file_reader, self.mode_config, self.cross_frame) @@ -158,6 +169,8 @@ class Comparator: match_result.loc[~match.gen_dtype_condition(match_result), bench_columns] = CompareConst.N_A # organize compare result table by renaming columns + if self.mode_config.dump_mode == Const.ALL and self.mode_config.first_diff_analyze: + self.mode_config.dump_mode = Const.SUMMARY create_table = CreateTable(self.mode_config) result_df, header = create_table.make_result_df(match_result) @@ -692,7 +705,7 @@ class CalcStatsDiff: # 相对误差转成百分比字符串 cond_ref_err = cond_not_nan_diff & ~condition_pt_zero result_df.loc[cond_ref_err, rel_err_name] = ( - result_df.loc[cond_ref_err, diff_name] / bench_val[cond_ref_err] * 100) + result_df.loc[cond_ref_err, diff_name] / bench_val[cond_ref_err].astype(float) * 100) result_df.loc[cond_ref_err, rel_err_name] = (result_df.loc[cond_ref_err, rel_err_name].abs().astype(str) + '%') magnitude = self.get_number(result_df[diff_name]).abs() / (pd.Series( @@ -709,7 +722,7 @@ class CalcStatsDiff: condition_md5_equal = result_df[CompareConst.NPU_MD5] == result_df[CompareConst.BENCH_MD5] result_df.loc[condition_md5_equal, CompareConst.RESULT] = CompareConst.PASS result_df.loc[~condition_md5_equal & ~condition_no_bench, CompareConst.RESULT] = CompareConst.DIFF - elif self.mode_config.dump_mode == Const.SUMMARY: + elif self.mode_config.first_diff_analyze or self.mode_config.dump_mode == Const.SUMMARY: warning_list = [ self.calc_summary_diff(result_df, condition_no_bench, stats_index) for stats_index in ['max', 'min', 'mean', 'l2norm'] @@ -743,6 +756,7 @@ def setup_comparison(input_param, output_path, **kwargs) -> ComparisonConfig: cell_mapping=kwargs.get('cell_mapping', {}), api_mapping=kwargs.get('api_mapping', {}), layer_mapping=kwargs.get('layer_mapping', {}), + first_diff_analyze=kwargs.get('first_diff_analyze', False), compared_file_type='', ) diff --git a/debug/accuracy_tools/msprobe/core/compare/config.py b/debug/accuracy_tools/msprobe/core/compare/config.py index 3aa237a6ab65763a98036c776b512474aa2d31e7..71a512ea976c1a027e7cd21b3d0fdc64c2828542 100644 --- a/debug/accuracy_tools/msprobe/core/compare/config.py +++ b/debug/accuracy_tools/msprobe/core/compare/config.py @@ -26,6 +26,7 @@ class ModeConfig: self.fuzzy_match = kwargs.get('fuzzy_match', False) self.highlight = kwargs.get('highlight', False) self.dump_mode = kwargs.get('dump_mode', Const.SUMMARY) + self.first_diff_analyze = kwargs.get('first_diff_analyze', False) self.compared_file_type = kwargs.get('compared_file_type', Const.DUMP_JSON_FILE) @@ -69,4 +70,4 @@ class MappingDict: else: raise TypeError(f"The type of parameter `data_mapping` must be dict, str or None, but got " f"{type(data_mapping)}") - return data_mapping_dict \ No newline at end of file + return data_mapping_dict diff --git a/debug/accuracy_tools/msprobe/core/compare/diff_analyze/__init__.py b/debug/accuracy_tools/msprobe/core/compare/diff_analyze/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/debug/accuracy_tools/msprobe/core/compare/diff_analyze/diff_analyze_threshold.yaml b/debug/accuracy_tools/msprobe/core/compare/diff_analyze/diff_analyze_threshold.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4b35d10bfb99f00a93e2fd6ad69112c6a40efce1 --- /dev/null +++ b/debug/accuracy_tools/msprobe/core/compare/diff_analyze/diff_analyze_threshold.yaml @@ -0,0 +1,14 @@ +compare_metrics: + - MaxRelativeErr + - MinRelativeErr + - MeanRelativeErr + - NormRelativeErr + +MaxRelativeErr: + - 0.5 +MinRelativeErr: + - 0.5 +MeanRelativeErr: + - 0.5 +NormRelativeErr: + - 0.5 \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/core/compare/diff_analyze/first_diff_analyze.py b/debug/accuracy_tools/msprobe/core/compare/diff_analyze/first_diff_analyze.py new file mode 100644 index 0000000000000000000000000000000000000000..f1192d895190dfce472bf82b6d213a2fa081d210 --- /dev/null +++ b/debug/accuracy_tools/msprobe/core/compare/diff_analyze/first_diff_analyze.py @@ -0,0 +1,115 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from msprobe.core.common.const import Const, CompareConst +from msprobe.core.common.utils import logger, CompareException +from msprobe.core.common.file_utils import load_yaml +from msprobe.core.compare.config import ModeConfig +from msprobe.core.compare.utils import gen_api_batches + + +cur_dir = os.path.dirname(os.path.realpath(__file__)) +diff_threshold_yaml_path = os.path.join(cur_dir, 'diff_analyze_threshold.yaml') +thresholds = load_yaml(diff_threshold_yaml_path) +cmp_metrics = thresholds.get('compare_metrics') + + +class FirstDiffAnalyze: + def __init__(self, mode_config: ModeConfig): + self.mode_config = mode_config + + @staticmethod + def single_metric_diff_check(cmp_metric, metric_value): + threshold = thresholds.get(cmp_metric, None) + if threshold is None: + logger.error(f"Check diff or {cmp_metric} need to configure the threshold. " + f"Please configure it in 'diff_analyze_threshold.yaml'.") + raise CompareException(CompareException.MISSING_THRESHOLD_ERROR) + if not isinstance(threshold, list) or len(threshold) != 1: + logger.error(f"{cmp_metric} threshold configure wrong. Please check.") + raise CompareException(CompareException.WRONG_THRESHOLD_ERROR) + if isinstance(metric_value, str) and metric_value.endswith('%'): + metric_value_float = float(metric_value[:-1]) / 100 + if metric_value_float > threshold[0]: + return True + return False + + def single_api_check(self, result_slice, header): + """ + 单个api差异检查 + + :param result_slice: 数据切片 + :param header: 列名列表 + :return: {'is_same': bool, 'op_items': list[dict]} + """ + single_check_result = { + 'is_same': True, + 'op_items': [] + } + + column_indices = {name: idx for idx, name in enumerate(header)} + + for line in result_slice: + op_item = { + column_name: line[column_indices[column_name]] + for column_name in header + } + single_check_result['op_items'].append(op_item) + + # set is_same + if self.mode_config.dump_mode == Const.MD5: + if line[column_indices[CompareConst.RESULT]] == CompareConst.DIFF: + single_check_result['is_same'] = False + else: + for cmp_metric in cmp_metrics: + metric_value = line[column_indices[cmp_metric]] + if self.single_metric_diff_check(cmp_metric, metric_value): + single_check_result['is_same'] = False + break + return single_check_result + + def check(self, result_df): + """ + 比对后循环遍历api检查差异 + example: + { + 'Functional.conv2d.0.forward': { + 'is_same': true, + 'op_items': [ + { + 'NPU name': 'Functional.conv2d.0.forward.input.0', + 'Bench name': 'Functional.conv2d.0.forward.input.0', + 'xxx': 1, + 'NormRelativeErr': 2, + 'yyy': 3, + ... + } + ] + } + } + """ + result = result_df.values + header = result_df.columns.tolist() + + api_batches = gen_api_batches(result) + + check_result = {} + for api_batch in api_batches: + result_slice = result[api_batch.start: api_batch.params_grad_end_index] + check_result[api_batch.api_name] = self.single_api_check(result_slice, header) + + return check_result diff --git a/debug/accuracy_tools/msprobe/core/compare/highlight.py b/debug/accuracy_tools/msprobe/core/compare/highlight.py index 9e5fe0559c933eafac923f00f0c8b9c4f18fbf96..25eea7a5b43196819a974009cc7783f6275ea190 100644 --- a/debug/accuracy_tools/msprobe/core/compare/highlight.py +++ b/debug/accuracy_tools/msprobe/core/compare/highlight.py @@ -16,7 +16,6 @@ import abc import math import multiprocessing -import re from collections import namedtuple import numpy as np @@ -28,8 +27,8 @@ from tqdm import tqdm from msprobe.core.common.const import CompareConst, Const from msprobe.core.common.file_utils import save_workbook from msprobe.core.common.log import logger -from msprobe.core.common.utils import get_header_index, safe_get_value -from msprobe.core.compare.utils import table_value_is_valid, CompareException +from msprobe.core.common.utils import get_header_index +from msprobe.core.compare.utils import table_value_is_valid, gen_api_batches from msprobe.core.compare.config import ModeConfig @@ -160,65 +159,10 @@ class HighlightRules: } -class ApiBatch: - def __init__(self, api_name: str, start: int): - self.api_name = api_name - self.start = start - self.input_len = 1 # input的数量 - self.params_end_index = start + 1 # params的结束index - self.output_end_index = start + 1 # output的结束index - self.params_grad_end_index = start + 1 # params_grad的结束index - # 内部state的标志("input", "output", "parameters", "parameters_grad"), - # 用于控制计算input_len, output_end_index, params_end_index, self.params_grad_end_index - self._state = Const.INPUT # api_batch初始化为input - - def set_state(self, state: str): - """设置当前状态""" - if state in {Const.INPUT, Const.OUTPUT, Const.KWARGS, Const.PARAMS, Const.PARAMS_GRAD}: - self._state = state - else: - raise ValueError(f"Invalid state: {state}") - - def increment(self, state: str): - self.set_state(state) - if self._state == Const.INPUT or self._state == Const.KWARGS: - self.input_len += 1 - self.params_end_index += 1 - self.output_end_index += 1 - if self._state == Const.PARAMS: - self.params_end_index += 1 - self.output_end_index += 1 - if self._state == Const.OUTPUT: - self.output_end_index += 1 - self.params_grad_end_index += 1 - - class HighLight: def __init__(self, mode_config: ModeConfig): self.mode_config = mode_config - @staticmethod - def api_batches_update(api_batches, api_name, state, index): - """ - 当一个api的所有item更新完后,input, output的索引范围: - input: [start: start+input_len] - output: [start+input_len: output_end_index] - params: [output_end_index: params_end_index] - """ - if not api_batches: - api_batches.append(ApiBatch(api_name, index)) - else: - api_batch = api_batches[-1] - if api_batch.api_name == api_name or ( - not re.search(Const.REGEX_FORWARD_BACKWARD, api_name) and api_name in api_batch.api_name): - try: - api_batch.increment(state) - except ValueError as e: - logger.error(f"api_batch: {api_batch} with invalid state, please check! {e}") - raise CompareException(CompareException.INVALID_STATE_ERROR) from e - else: - api_batches.append(ApiBatch(api_name, index)) - @staticmethod def check_indices_numeric(api_items, indices: list): """检查指定索引处的值是否都为数字类型(int 或 float)""" @@ -273,11 +217,7 @@ class HighLight: def find_compare_result_error_rows(self, result_df, highlight_dict): """将dataframe根据API分组,并找到有误差的算子用于高亮""" result = result_df.values - api_batches = [] - for i, res_i in enumerate(result): - api_name = safe_get_value(res_i, -1, "res_i") # 内部定义倒数第一个元素必是api_origin_name - state = safe_get_value(res_i, -2, "res_i") # 内部定义倒数第二个元素必是state - self.api_batches_update(api_batches, api_name, state, i) + api_batches = gen_api_batches(result) with tqdm(total=len(api_batches), desc="API/Module Analyse Progress", unit="item", ncols=100) as progress_bar: for api_batch in api_batches: self.find_error_rows(result[api_batch.start: api_batch.params_grad_end_index], api_batch, diff --git a/debug/accuracy_tools/msprobe/core/compare/utils.py b/debug/accuracy_tools/msprobe/core/compare/utils.py index 1f028fac6aa71b2850118cce7c422ee405afa52b..8acbd0b12923ebf21d75fa90d20d021d8bcd15b3 100644 --- a/debug/accuracy_tools/msprobe/core/compare/utils.py +++ b/debug/accuracy_tools/msprobe/core/compare/utils.py @@ -164,11 +164,13 @@ def gen_op_item(op_data, op_name, state): op_item['full_op_name'] = data_name.rsplit(Const.SEP, 1)[0] if data_name != '-1' else op_name op_item[Const.STATE] = state + # 补齐统计量字段 params = [Const.MAX, Const.MIN, Const.MEAN, Const.NORM] for i in params: if i not in op_item: op_item[i] = None + # special cases if not op_item.get('dtype'): if op_item.get('type') == 'torch.Size': op_item['dtype'] = op_data.get('type') @@ -181,11 +183,18 @@ def gen_op_item(op_data, op_name, state): op_item['shape'] = '[]' for i in params: op_item[i] = op_data.get('value') + elif op_name.split(Const.SEP)[-1] in ['src', 'dst', 'group_src', 'group_dst']: + op_item['dtype'] = op_data.get('type') + op_item['shape'] = '[]' + for i in params: + op_item[i] = str(op_data.get('value')) + op_item['md5'] = str(op_data.get('value')) elif op_item.get('type') == 'torch.ProcessGroup': op_item['dtype'] = op_data.get('type') op_item['shape'] = '[]' for i in params: op_item[i] = str(op_data.get('group_ranks')) + op_item['md5'] = str(op_data.get('group_ranks')) else: op_item['dtype'] = str(type(op_data.get('value'))) op_item['shape'] = '[]' @@ -277,6 +286,61 @@ def table_value_is_valid(value: str) -> bool: return True +class ApiBatch: + def __init__(self, api_name: str, start: int): + self.api_name = api_name + self.start = start + self.input_len = 1 # input的数量 + self.params_end_index = start + 1 # params的结束index + self.output_end_index = start + 1 # output的结束index + self.params_grad_end_index = start + 1 # params_grad的结束index + # 内部state的标志("input", "output", "parameters", "parameters_grad"), + # 用于控制计算input_len, output_end_index, params_end_index, self.params_grad_end_index + self._state = Const.INPUT # api_batch初始化为input + + def set_state(self, state: str): + """设置当前状态""" + if state in {Const.INPUT, Const.OUTPUT, Const.KWARGS, Const.PARAMS, Const.PARAMS_GRAD}: + self._state = state + else: + raise ValueError(f"Invalid state: {state}") + + def increment(self, state: str): + self.set_state(state) + if self._state == Const.INPUT or self._state == Const.KWARGS: + self.input_len += 1 + self.params_end_index += 1 + self.output_end_index += 1 + if self._state == Const.PARAMS: + self.params_end_index += 1 + self.output_end_index += 1 + if self._state == Const.OUTPUT: + self.output_end_index += 1 + self.params_grad_end_index += 1 + + +def api_batches_update(api_batches, api_name, state, index): + """ + 当一个api的所有item更新完后,input, output的索引范围: + input: [start: start+input_len] + output: [start+input_len: output_end_index] + params: [output_end_index: params_end_index] + """ + if not api_batches: + api_batches.append(ApiBatch(api_name, index)) + else: + api_batch = api_batches[-1] + if api_batch.api_name == api_name or ( + not re.search(Const.REGEX_FORWARD_BACKWARD, api_name) and api_name in api_batch.api_name): + try: + api_batch.increment(state) + except ValueError as e: + logger.error(f"api_batch: {api_batch} with invalid state, please check! {e}") + raise CompareException(CompareException.INVALID_STATE_ERROR) from e + else: + api_batches.append(ApiBatch(api_name, index)) + + def reorder_op_name_list(op_name_list, state_list): if not op_name_list: return op_name_list, state_list @@ -531,6 +595,15 @@ def make_result_table(result, dump_mode, stack_mode): return result_df +def gen_api_batches(result: np.ndarray): + api_batches = [] + for i, res_i in enumerate(result): + api_name = safe_get_value(res_i, -1, "res_i") # 内部定义倒数第一个元素必是api_origin_name + state = safe_get_value(res_i, -2, "res_i") # 内部定义倒数第二个元素必是state + api_batches_update(api_batches, api_name, state, i) + return api_batches + + def _compare_parser(parser): parser.add_argument("-i", "--input_path", dest="input_path", type=str, help=" The compare input path, a dict json.", required=True) @@ -556,6 +629,9 @@ def _compare_parser(parser): def compare_distributed_inner(npu_dump_dir, bench_dump_dir, output_path, compare_func, **kwargs): + if not isinstance(kwargs.get('first_diff_analyze', False), bool): + logger.error('kwargs: first_diff_analyze should be bool, please check!') + raise CompareException(CompareException.INVALID_PARAM_ERROR) if kwargs.get('suffix'): logger.error("Argument 'suffix' is not supported for compare_distributed.") raise CompareException(CompareException.INVALID_PARAM_ERROR) diff --git a/debug/accuracy_tools/msprobe/pytorch/compare/pt_compare.py b/debug/accuracy_tools/msprobe/pytorch/compare/pt_compare.py index dd948f7d61e43b83fb2801bd1ecb6a06c82c0898..b3be3e793df30ada65a98374eae4f358da433fd3 100644 --- a/debug/accuracy_tools/msprobe/pytorch/compare/pt_compare.py +++ b/debug/accuracy_tools/msprobe/pytorch/compare/pt_compare.py @@ -37,6 +37,7 @@ def compare(input_param, output_path, **kwargs): 'fuzzy_match': config.fuzzy_match, 'highlight': config.highlight, 'dump_mode': config.dump_mode, + 'first_diff_analyze': config.first_diff_analyze, 'compared_file_type': config.compared_file_type } mode_config = ModeConfig(**config_dict) diff --git a/debug/accuracy_tools/msprobe/pytorch/compare/pt_diff_analyze.py b/debug/accuracy_tools/msprobe/pytorch/compare/pt_diff_analyze.py new file mode 100644 index 0000000000000000000000000000000000000000..b558a20b6f592ac9ebd758a0041155beee413caa --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/compare/pt_diff_analyze.py @@ -0,0 +1,21 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from msprobe.pytorch.compare.distributed_compare import compare_distributed + + +def pt_diff_analyze(npu_dump_dir, bench_dump_dir, output_path, first_diff_analyze): + compare_distributed(npu_dump_dir, bench_dump_dir, output_path, first_diff_analyze=first_diff_analyze) diff --git a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare_utils.py b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare_utils.py index f15dbcf6e7fb1a1048ade9d44d4d5ff8b7dbaa61..173fb550067de76b77524699430691fa87df6c58 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare_utils.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare_utils.py @@ -13,7 +13,7 @@ from msprobe.core.common.const import CompareConst, Const from msprobe.core.common.utils import CompareException from msprobe.core.compare.utils import ApiItemInfo, _compare_parser, check_and_return_dir_contents, extract_json, \ count_struct, get_accuracy, get_rela_diff_summary_mode, merge_tensor, op_item_parse, read_op, result_item_init, \ - stack_column_process, table_value_is_valid, reorder_op_name_list, reorder_op_x_list, gen_op_item + stack_column_process, table_value_is_valid, reorder_op_name_list, reorder_op_x_list, gen_op_item, ApiBatch # test_read_op_1 op_data = { @@ -786,3 +786,85 @@ class TestGenOpItem(unittest.TestCase): expected_md5 = f"{zlib.crc32(str(op_data['value']).encode()):08x}" self.assertEqual(result['md5'], expected_md5) self.assertEqual(result['state'], 'input') + + +class TestApiBatch(unittest.TestCase): + def test_ApiBatch_increment_input(self): + api_name = "functional.conv2d" + start = 2 + api_batch = ApiBatch(api_name, start) + + api_batch.increment(Const.INPUT) + + self.assertEqual(api_batch._state, Const.INPUT) + self.assertEqual(api_batch.input_len, 2) + self.assertEqual(api_batch.params_end_index, 4) + self.assertEqual(api_batch.output_end_index, 4) + self.assertEqual(api_batch.params_grad_end_index, 4) + + def test_ApiBatch_increment_output(self): + api_name = "functional.conv2d" + start = 2 + api_batch = ApiBatch(api_name, start) + + api_batch.increment(Const.OUTPUT) + + self.assertEqual(api_batch._state, Const.OUTPUT) + self.assertEqual(api_batch.input_len, 1) + self.assertEqual(api_batch.params_end_index, 3) + self.assertEqual(api_batch.output_end_index, 4) + self.assertEqual(api_batch.params_grad_end_index, 4) + + def test_ApiBatch_increment_kwargs(self): + api_name = "functional.conv2d" + start = 2 + api_batch = ApiBatch(api_name, start) + + api_batch.increment(Const.KWARGS) + + self.assertEqual(api_batch._state, Const.KWARGS) + self.assertEqual(api_batch.input_len, 2) + self.assertEqual(api_batch.params_end_index, 4) + self.assertEqual(api_batch.output_end_index, 4) + self.assertEqual(api_batch.params_grad_end_index, 4) + + def test_ApiBatch_increment_params(self): + api_name = "functional.conv2d" + start = 2 + api_batch = ApiBatch(api_name, start) + + api_batch.increment(Const.PARAMS) + + self.assertEqual(api_batch._state, Const.PARAMS) + self.assertEqual(api_batch.input_len, 1) + self.assertEqual(api_batch.params_end_index, 4) + self.assertEqual(api_batch.output_end_index, 4) + self.assertEqual(api_batch.params_grad_end_index, 4) + + def test_ApiBatch_increment_multiple_input(self): + api_name = "functional.conv2d" + start = 2 + api_batch = ApiBatch(api_name, start) + + api_batch.increment(Const.INPUT) + api_batch.increment(Const.INPUT) + + self.assertEqual(api_batch._state, Const.INPUT) + self.assertEqual(api_batch.input_len, 3) + self.assertEqual(api_batch.params_end_index, 5) + self.assertEqual(api_batch.output_end_index, 5) + self.assertEqual(api_batch.params_grad_end_index, 5) + + def test_ApiBatch_increment_multiple_output(self): + api_name = "functional.conv2d" + start = 2 + api_batch = ApiBatch(api_name, start) + + api_batch.increment(Const.OUTPUT) + api_batch.increment(Const.OUTPUT) + + self.assertEqual(api_batch._state, Const.OUTPUT) + self.assertEqual(api_batch.input_len, 1) + self.assertEqual(api_batch.params_end_index, 3) + self.assertEqual(api_batch.output_end_index, 5) + self.assertEqual(api_batch.params_grad_end_index, 5) diff --git a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_cmp_first_diff_analyze.py b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_cmp_first_diff_analyze.py new file mode 100644 index 0000000000000000000000000000000000000000..c7919efba7e0240ff0d3398357f2ad4d45c1df30 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_cmp_first_diff_analyze.py @@ -0,0 +1,170 @@ +import unittest +from unittest.mock import patch + +import pandas as pd + +from msprobe.core.common.const import Const, CompareConst +from msprobe.core.common.utils import CompareException +from msprobe.core.compare.diff_analyze.first_diff_analyze import FirstDiffAnalyze +from msprobe.core.compare.config import ModeConfig + + +class TestFirstDiffAnalyze(unittest.TestCase): + def setUp(self): + self.header = ['NPU name', 'L2norm diff', + 'MaxRelativeErr', 'MinRelativeErr', 'MeanRelativeErr', 'NormRelativeErr', + 'state', 'api_origin_name'] + self.data = [ + ['Functional.conv2d.0.forward.input.0', 1, '0.0%', '0.0%', '0.0%', '0.0%', 'input', 'Functional.conv2d.0.forward'], + ['Functional.conv2d.0.forward.input.1', 1, '99.0%', '99.0%', '99.0%', '99.0%', 'input', 'Functional.conv2d.0.forward'] + ] + self.result_df = pd.DataFrame(self.data, columns=self.header) + + @patch('msprobe.core.compare.diff_analyze.first_diff_analyze.thresholds', + {'compare_metrics': ['MaxRelativeErr', 'NormRelativeErr'], 'MaxRelativeErr': [0.5]}) + def test_single_metric_diff_check_true(self): + mode_config = ModeConfig(first_diff_analyze=True) + first_diff_analyze = FirstDiffAnalyze(mode_config) + result = first_diff_analyze.single_metric_diff_check('MaxRelativeErr', '60.0%') + self.assertTrue(result) + + @patch('msprobe.core.compare.diff_analyze.first_diff_analyze.thresholds', + {'compare_metrics': ['MaxRelativeErr', 'NormRelativeErr'], 'MaxRelativeErr': [0.5]}) + def test_single_metric_diff_check_false(self): + mode_config = ModeConfig(first_diff_analyze=True) + first_diff_analyze = FirstDiffAnalyze(mode_config) + result = first_diff_analyze.single_metric_diff_check('MaxRelativeErr', '30.0%') + self.assertFalse(result) + + @patch('msprobe.core.compare.diff_analyze.first_diff_analyze.thresholds', + {'compare_metrics': ['MaxRelativeErr', 'NormRelativeErr'], 'NormRelativeErr': [0.5]}) + def test_single_metric_diff_check_miss_threshold(self): + mode_config = ModeConfig(first_diff_analyze=True) + first_diff_analyze = FirstDiffAnalyze(mode_config) + with self.assertRaises(CompareException) as context: + result = first_diff_analyze.single_metric_diff_check('MaxRelativeErr', '30.0%') + self.assertEqual(context.exception.code, CompareException.MISSING_THRESHOLD_ERROR) + + @patch('msprobe.core.compare.diff_analyze.first_diff_analyze.thresholds', + {'compare_metrics': ['MaxRelativeErr', 'NormRelativeErr'], 'MaxRelativeErr': [0.5, 1.0]}) + def test_single_metric_diff_check_wrong_threshold(self): + mode_config = ModeConfig(first_diff_analyze=True) + first_diff_analyze = FirstDiffAnalyze(mode_config) + with self.assertRaises(CompareException) as context: + result = first_diff_analyze.single_metric_diff_check('MaxRelativeErr', '30.0%') + self.assertEqual(context.exception.code, CompareException.WRONG_THRESHOLD_ERROR) + + def test_single_api_check_within_threshold(self): + result_slice = [ + ['Functional.conv2d.0.forward.input.0', 1, '0.0%', '0.0%', '0.0%', '0.0%', 'input', 'Functional.conv2d.0.forward'], + ['Functional.conv2d.0.forward.input.1', 1, '0.1%', '0.1%', '0.1%', '0.1%', 'input', 'Functional.conv2d.0.forward'] + ] + expected_result = { + 'is_same': True, + 'op_items': [ + {'NPU name': 'Functional.conv2d.0.forward.input.0', 'L2norm diff': 1, + 'MaxRelativeErr': '0.0%', 'MinRelativeErr': '0.0%', + 'MeanRelativeErr': '0.0%', 'NormRelativeErr': '0.0%', + 'state': 'input', 'api_origin_name': 'Functional.conv2d.0.forward'}, + {'NPU name': 'Functional.conv2d.0.forward.input.1', 'L2norm diff': 1, + 'MaxRelativeErr': '0.1%', 'MinRelativeErr': '0.1%', + 'MeanRelativeErr': '0.1%', 'NormRelativeErr': '0.1%', + 'state': 'input', 'api_origin_name': 'Functional.conv2d.0.forward'} + ] + } + mode_config = ModeConfig(first_diff_analyze=True) + first_diff_analyze = FirstDiffAnalyze(mode_config) + result = first_diff_analyze.single_api_check(result_slice, self.header) + self.assertEqual(result, expected_result) + + def test_single_api_check_exceed_threshold(self): + result_slice = [ + ['Functional.conv2d.0.forward.input.0', 1, '88.0%', '88.0%', '88.0%', '88.0%', 'input', 'Functional.conv2d.0.forward'], + ['Functional.conv2d.0.forward.input.1', 1, '99.0%', '99.0%', '99.0%', '99.0%', 'input', 'Functional.conv2d.0.forward'] + ] + expected_result = { + 'is_same': False, + 'op_items': [ + {'NPU name': 'Functional.conv2d.0.forward.input.0', 'L2norm diff': 1, + 'MaxRelativeErr': '88.0%', 'MinRelativeErr': '88.0%', + 'MeanRelativeErr': '88.0%', 'NormRelativeErr': '88.0%', + 'state': 'input', 'api_origin_name': 'Functional.conv2d.0.forward'}, + {'NPU name': 'Functional.conv2d.0.forward.input.1', 'L2norm diff': 1, + 'MaxRelativeErr': '99.0%', 'MinRelativeErr': '99.0%', + 'MeanRelativeErr': '99.0%', 'NormRelativeErr': '99.0%', + 'state': 'input', 'api_origin_name': 'Functional.conv2d.0.forward'}, + ] + } + mode_config = ModeConfig(first_diff_analyze=True) + first_diff_analyze = FirstDiffAnalyze(mode_config) + result = first_diff_analyze.single_api_check(result_slice, self.header) + self.assertEqual(result, expected_result) + + def test_single_api_check_md5_same_true(self): + md5_header = CompareConst.MD5_COMPARE_RESULT_HEADER + [CompareConst.STACK, Const.STATE, Const.API_ORIGIN_NAME] + result_slice = [ + ['Functional.conv2d.0.forward.input.0', 'Functional.conv2d.0.forward.input.0', 'torch.int32', 'torch.int32', + '[]', '[]', '2144df1c', '2144df1c', 'pass', '', 'input', 'Functional.conv2d.0.forward'] + ] + expected_result = { + 'is_same': True, + 'op_items': [ + {CompareConst.NPU_NAME: 'Functional.conv2d.0.forward.input.0', + CompareConst.BENCH_NAME: 'Functional.conv2d.0.forward.input.0', + CompareConst.NPU_DTYPE: 'torch.int32', CompareConst.BENCH_DTYPE: 'torch.int32', + CompareConst.NPU_SHAPE: '[]', CompareConst.BENCH_SHAPE: '[]', + CompareConst.NPU_MD5: '2144df1c', CompareConst.BENCH_MD5: '2144df1c', + CompareConst.RESULT: 'pass', CompareConst.STACK: '', + Const.STATE: 'input', Const.API_ORIGIN_NAME: 'Functional.conv2d.0.forward' + } + ] + } + mode_config = ModeConfig(dump_mode=Const.MD5, first_diff_analyze=True) + first_diff_analyze = FirstDiffAnalyze(mode_config) + result = first_diff_analyze.single_api_check(result_slice, md5_header) + self.assertEqual(result, expected_result) + + def test_single_api_check_md5_same_false(self): + md5_header = CompareConst.MD5_COMPARE_RESULT_HEADER + [CompareConst.STACK, Const.STATE, Const.API_ORIGIN_NAME] + result_slice = [ + ['Functional.conv2d.0.forward.input.0', 'Functional.conv2d.0.forward.input.0', 'torch.int32', 'torch.int32', + '[]', '[]', '2144df1c', '2100df1c', 'Different', '', 'input', 'Functional.conv2d.0.forward'] + ] + expected_result = { + 'is_same': False, + 'op_items': [ + {CompareConst.NPU_NAME: 'Functional.conv2d.0.forward.input.0', + CompareConst.BENCH_NAME: 'Functional.conv2d.0.forward.input.0', + CompareConst.NPU_DTYPE: 'torch.int32', CompareConst.BENCH_DTYPE: 'torch.int32', + CompareConst.NPU_SHAPE: '[]', CompareConst.BENCH_SHAPE: '[]', + CompareConst.NPU_MD5: '2144df1c', CompareConst.BENCH_MD5: '2100df1c', + CompareConst.RESULT: 'Different', CompareConst.STACK: '', + Const.STATE: 'input', Const.API_ORIGIN_NAME: 'Functional.conv2d.0.forward' + } + ] + } + mode_config = ModeConfig(dump_mode=Const.MD5, first_diff_analyze=True) + first_diff_analyze = FirstDiffAnalyze(mode_config) + result = first_diff_analyze.single_api_check(result_slice, md5_header) + self.assertEqual(result, expected_result) + + def test_check_summary(self): + expected_result = { + 'Functional.conv2d.0.forward': { + 'is_same': False, + 'op_items': [ + {'NPU name': 'Functional.conv2d.0.forward.input.0', 'L2norm diff': 1, + 'MaxRelativeErr': '0.0%', 'MinRelativeErr': '0.0%', + 'MeanRelativeErr': '0.0%', 'NormRelativeErr': '0.0%', + 'state': 'input', 'api_origin_name': 'Functional.conv2d.0.forward'}, + {'NPU name': 'Functional.conv2d.0.forward.input.1', 'L2norm diff': 1, + 'MaxRelativeErr': '99.0%', 'MinRelativeErr': '99.0%', + 'MeanRelativeErr': '99.0%', 'NormRelativeErr': '99.0%', + 'state': 'input', 'api_origin_name': 'Functional.conv2d.0.forward'}, + ] + } + } + mode_config = ModeConfig(first_diff_analyze=True) + first_diff_analyze = FirstDiffAnalyze(mode_config) + result = first_diff_analyze.check(self.result_df) + self.assertEqual(result, expected_result) diff --git a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_cmp_highlight.py b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_cmp_highlight.py index 933c846013d9fe35d1f0a850c3ff0cc63dd5b019..5a4ca7de47f401c4eb97ad4c224dfd1ae0a92262 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_cmp_highlight.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_cmp_highlight.py @@ -13,9 +13,10 @@ from openpyxl import load_workbook from openpyxl.styles import PatternFill from msprobe.core.common.const import CompareConst, Const -from msprobe.core.compare.highlight import ApiBatch, CheckMaxRelativeDiff, CheckOrderMagnitude, \ +from msprobe.core.compare.highlight import CheckMaxRelativeDiff, CheckOrderMagnitude, \ CheckOneThousandErrorRatio, CheckCosineSimilarity, add_highlight_row_info, HighLight from msprobe.core.compare.config import ModeConfig +from msprobe.core.compare.utils import ApiBatch summary_line_input = ['Functional_batch_norm_0_forward.input.0', 'Functional_batch_norm_0_forward.input.0', @@ -210,87 +211,6 @@ class TestUtilsMethods(unittest.TestCase): result = CheckMaxRelativeDiff().apply(info, color_columns, dump_mode=Const.SUMMARY) self.assertEqual(result, None) - def test_ApiBatch_increment_input(self): - api_name = "functional.conv2d" - start = 2 - api_batch = ApiBatch(api_name, start) - - api_batch.increment(Const.INPUT) - - self.assertEqual(api_batch._state, Const.INPUT) - self.assertEqual(api_batch.input_len, 2) - self.assertEqual(api_batch.params_end_index, 4) - self.assertEqual(api_batch.output_end_index, 4) - self.assertEqual(api_batch.params_grad_end_index, 4) - - def test_ApiBatch_increment_output(self): - api_name = "functional.conv2d" - start = 2 - api_batch = ApiBatch(api_name, start) - - api_batch.increment(Const.OUTPUT) - - self.assertEqual(api_batch._state, Const.OUTPUT) - self.assertEqual(api_batch.input_len, 1) - self.assertEqual(api_batch.params_end_index, 3) - self.assertEqual(api_batch.output_end_index, 4) - self.assertEqual(api_batch.params_grad_end_index, 4) - - def test_ApiBatch_increment_kwargs(self): - api_name = "functional.conv2d" - start = 2 - api_batch = ApiBatch(api_name, start) - - api_batch.increment(Const.KWARGS) - - self.assertEqual(api_batch._state, Const.KWARGS) - self.assertEqual(api_batch.input_len, 2) - self.assertEqual(api_batch.params_end_index, 4) - self.assertEqual(api_batch.output_end_index, 4) - self.assertEqual(api_batch.params_grad_end_index, 4) - - def test_ApiBatch_increment_params(self): - api_name = "functional.conv2d" - start = 2 - api_batch = ApiBatch(api_name, start) - - api_batch.increment(Const.PARAMS) - - self.assertEqual(api_batch._state, Const.PARAMS) - self.assertEqual(api_batch.input_len, 1) - self.assertEqual(api_batch.params_end_index, 4) - self.assertEqual(api_batch.output_end_index, 4) - self.assertEqual(api_batch.params_grad_end_index, 4) - - def test_ApiBatch_increment_multiple_input(self): - api_name = "functional.conv2d" - start = 2 - api_batch = ApiBatch(api_name, start) - - api_batch.increment(Const.INPUT) - api_batch.increment(Const.INPUT) - - self.assertEqual(api_batch._state, Const.INPUT) - self.assertEqual(api_batch.input_len, 3) - self.assertEqual(api_batch.params_end_index, 5) - self.assertEqual(api_batch.output_end_index, 5) - self.assertEqual(api_batch.params_grad_end_index, 5) - - def test_ApiBatch_increment_multiple_output(self): - api_name = "functional.conv2d" - start = 2 - api_batch = ApiBatch(api_name, start) - - api_batch.increment(Const.OUTPUT) - api_batch.increment(Const.OUTPUT) - - self.assertEqual(api_batch._state, Const.OUTPUT) - self.assertEqual(api_batch.input_len, 1) - self.assertEqual(api_batch.params_end_index, 3) - self.assertEqual(api_batch.output_end_index, 5) - self.assertEqual(api_batch.params_grad_end_index, 5) - - def test_find_error_rows_normal(self): compare_result = np.array([ ["Functional.linear.0.forward.input.0", "Functional.linear.0.forward.input.0", @@ -459,13 +379,6 @@ class TestUtilsMethods(unittest.TestCase): add_highlight_row_info(color_list, num, highlight_err_msg) self.assertEqual(color_list, [(1, ["a", "b"]), (5, ["c", "highlight"])]) - def test_add_highlight_row_info_new(self): - color_list = [(1, ["a", "b"]), (5, ["c"])] - num = 6 - highlight_err_msg = "highlight" - add_highlight_row_info(color_list, num, highlight_err_msg) - self.assertEqual(color_list, [(1, ["a", "b"]), (5, ["c"]), (6, ["highlight"])]) - def test_update_highlight_err_msg(self): data = [['Functional.linear.0.forward.input.0', 'Functional.linear.0.forward.input.0', 'torch.float32', 'torch.float32', [2, 2], [2, 2],