From 044b425b85f42311ff7d7880596129fcf9fcf310 Mon Sep 17 00:00:00 2001 From: i-robot Date: Mon, 25 Aug 2025 07:02:02 +0000 Subject: [PATCH] =?UTF-8?q?!5138=20=E3=80=90=E4=BC=98=E5=8C=96=E3=80=91?= =?UTF-8?q?=E6=AF=94=E5=AF=B9=E6=80=A7=E8=83=BD=E4=BC=98=E5=8C=96=20Merge?= =?UTF-8?q?=20pull=20request=20!5138=20from=20yinglinwei/master?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../msprobe/core/common/utils.py | 10 -- .../msprobe/core/compare/acc_compare.py | 33 ++-- .../msprobe/core/compare/check.py | 11 ++ .../msprobe/core/compare/highlight.py | 16 +- .../msprobe/core/compare/utils.py | 152 ++++++++++++++---- .../msprobe/test/core_ut/common/test_utils.py | 9 -- .../test/core_ut/compare/test_acc_compare.py | 20 +-- .../core_ut/compare/test_acc_compare_check.py | 27 +++- .../core_ut/compare/test_cmp_highlight.py | 30 ++-- 9 files changed, 213 insertions(+), 95 deletions(-) diff --git a/debug/accuracy_tools/msprobe/core/common/utils.py b/debug/accuracy_tools/msprobe/core/common/utils.py index f4ab06070..738455fd8 100644 --- a/debug/accuracy_tools/msprobe/core/common/utils.py +++ b/debug/accuracy_tools/msprobe/core/common/utils.py @@ -156,16 +156,6 @@ def check_compare_param(input_param, output_path, dump_mode, stack_mode): _check_json(stack_json, input_param.get("stack_json_path")) -def check_configuration_param(stack_mode=False, auto_analyze=True, fuzzy_match=False, highlight=False, - is_print_compare_log=True): - arg_list = [stack_mode, auto_analyze, fuzzy_match, highlight, is_print_compare_log] - arg_names = ['stack_mode', 'auto_analyze', 'fuzzy_match', 'highlight', 'is_print_compare_log'] - for arg, name in zip(arg_list, arg_names): - if not isinstance(arg, bool): - logger.error(f"Invalid input parameter, {name} which should be only bool type.") - raise CompareException(CompareException.INVALID_PARAM_ERROR) - - def _check_json(json_file_handle, file_name): tensor_line = json_file_handle.readline() if not tensor_line: diff --git a/debug/accuracy_tools/msprobe/core/compare/acc_compare.py b/debug/accuracy_tools/msprobe/core/compare/acc_compare.py index 77d6d6f91..2846d413a 100644 --- a/debug/accuracy_tools/msprobe/core/compare/acc_compare.py +++ b/debug/accuracy_tools/msprobe/core/compare/acc_compare.py @@ -28,9 +28,9 @@ from msprobe.core.common.exceptions import FileCheckException from msprobe.core.common.file_utils import load_json, remove_path, create_directory, save_excel, save_json from msprobe.core.common.log import logger from msprobe.core.common.utils import CompareException, add_time_with_xlsx, check_op_str_pattern_valid, \ - set_dump_path, get_dump_mode, check_compare_param, check_configuration_param, load_stack_json, get_file_type, \ - add_time_with_json -from msprobe.core.compare.check import check_dump_json_str, check_stack_json_str, cross_dtype_mapping + set_dump_path, get_dump_mode, check_compare_param, load_stack_json, get_file_type, add_time_with_json +from msprobe.core.compare.check import check_dump_json_str, check_stack_json_str, cross_dtype_mapping, \ + check_configuration_param from msprobe.core.compare.utils import merge_tensor, print_compare_ends_info, read_op, \ reorder_op_x_list, set_stack_json_path, check_api_info_len from msprobe.core.compare.config import ModeConfig, MappingConfig, MappingDict @@ -53,6 +53,7 @@ class ComparisonConfig: layer_mapping: dict compared_file_type: str first_diff_analyze: bool + is_print_compare_log: bool class Comparator: @@ -100,6 +101,7 @@ class Comparator: # get kwargs or set default value suffix = kwargs.get('suffix', '') + rank = suffix[1:] # process output file file_path = self.process_output_file(output_path, suffix, self.mode_config.compared_file_type) @@ -108,7 +110,7 @@ class Comparator: npu_json = input_param.get("npu_json_path") bench_json = input_param.get("bench_json_path") stack_json = input_param.get("stack_json_path") - result_df = self.compare_statistics([npu_json, bench_json, stack_json]) + result_df = self.compare_statistics([npu_json, bench_json, stack_json], rank) if not result_df.values.tolist(): logger.warning("Can`t match any op. No compare result file generated.") return @@ -130,7 +132,7 @@ class Comparator: if self.mode_config.highlight and len(result_df) <= CompareConst.MAX_EXCEL_LENGTH: # highlight if not too long highlight_dict = {"red_rows": set(), "yellow_rows": set(), "red_lines": [], "yellow_lines": []} - highlight = HighLight(self.mode_config) + highlight = HighLight(self.mode_config, rank) if self.mode_config.compared_file_type == Const.DUMP_JSON_FILE: highlight.find_compare_result_error_rows(result_df, highlight_dict) result_df.drop(columns=['state', 'api_origin_name'], inplace=True) # 删除中间数据,两列不落盘 @@ -147,9 +149,9 @@ class Comparator: print_compare_ends_info() - def compare_statistics(self, file_list): + def compare_statistics(self, file_list, rank): # load and parse json data - parse_data = ParseData(self.mode_config) + parse_data = ParseData(self.mode_config, rank) npu_df, bench_df = parse_data.parse(file_list) npu_df[[Const.DTYPE, Const.SHAPE]] = npu_df[[Const.DTYPE, Const.SHAPE]].astype(str) @@ -180,8 +182,9 @@ class Comparator: class ParseData: - def __init__(self, mode_config: ModeConfig): + def __init__(self, mode_config: ModeConfig, rank): self.mode_config = mode_config + self.rank = rank def parse(self, file_list): npu_json_path, bench_json_path, stack_json_path = file_list @@ -190,12 +193,12 @@ class ParseData: stack_json_data = load_stack_json(stack_json_path) if self.mode_config.stack_mode else None # parse json data and generate df - npu_df = self.gen_data_df(npu_json_data, stack_json_data) - bench_df = self.gen_data_df(bench_json_data, stack_json_data) + npu_df = self.gen_data_df(npu_json_data, stack_json_data, 'NPU') + bench_df = self.gen_data_df(bench_json_data, stack_json_data, 'Bench') return npu_df, bench_df - def gen_data_df(self, data_json, stack_json_data): + def gen_data_df(self, data_json, stack_json_data, device: str): result = { CompareConst.OP_NAME: [], Const.DTYPE: [], @@ -217,7 +220,9 @@ class ParseData: return pd.DataFrame(result) api_nums = len(apis_data) - progress_bar = tqdm(total=api_nums, desc="API/Module Read Progress", unit="api/module", ncols=100) + default_bar_desc = f'{device} API/Module Read Progress' + bar_desc_add_rank = f'[{self.rank}]' + default_bar_desc if self.rank else default_bar_desc + progress_bar = tqdm(total=api_nums, desc=bar_desc_add_rank, unit="api/module", ncols=100) # 从json中循环解析API数据,遍历所有API for data_name in apis_data: @@ -777,6 +782,7 @@ def setup_comparison(input_param, output_path, **kwargs) -> ComparisonConfig: layer_mapping=kwargs.get('layer_mapping', {}), first_diff_analyze=kwargs.get('first_diff_analyze', False), compared_file_type='', + is_print_compare_log=input_param.get('is_print_compare_log', True) ) set_dump_path(input_param) @@ -789,8 +795,7 @@ def setup_comparison(input_param, output_path, **kwargs) -> ComparisonConfig: else: config.stack_mode = set_stack_json_path(input_param) - check_configuration_param(config.stack_mode, config.auto_analyze, config.fuzzy_match, config.highlight, - input_param.get('is_print_compare_log', True)) + check_configuration_param(config) create_directory(output_path) check_compare_param(input_param, output_path, config.dump_mode, config.stack_mode) diff --git a/debug/accuracy_tools/msprobe/core/compare/check.py b/debug/accuracy_tools/msprobe/core/compare/check.py index acc90ec3d..78fb0d355 100644 --- a/debug/accuracy_tools/msprobe/core/compare/check.py +++ b/debug/accuracy_tools/msprobe/core/compare/check.py @@ -108,3 +108,14 @@ def check_stack_json_str(stack_info, op_name): else: logger.error(f"Expected stack_info to be a list, but got {type(stack_info).__name__} for '{op_name}'") raise CompareException(CompareException.INVALID_OBJECT_TYPE_ERROR) + + +def check_configuration_param(config): + arg_list = [config.stack_mode, config.auto_analyze, config.fuzzy_match, + config.highlight, config.first_diff_analyze, config.is_print_compare_log] + arg_names = ['stack_mode', 'auto_analyze', 'fuzzy_match', + 'highlight', 'first_diff_analyze', 'is_print_compare_log'] + for arg, name in zip(arg_list, arg_names): + if not isinstance(arg, bool): + logger.error(f"Invalid input parameter, {name} which should be only bool type.") + raise CompareException(CompareException.INVALID_PARAM_ERROR) diff --git a/debug/accuracy_tools/msprobe/core/compare/highlight.py b/debug/accuracy_tools/msprobe/core/compare/highlight.py index 71c32490b..64c599b00 100644 --- a/debug/accuracy_tools/msprobe/core/compare/highlight.py +++ b/debug/accuracy_tools/msprobe/core/compare/highlight.py @@ -183,8 +183,9 @@ class HighlightRules: class HighLight: - def __init__(self, mode_config: ModeConfig): + def __init__(self, mode_config: ModeConfig, rank): self.mode_config = mode_config + self.rank = rank @staticmethod def check_indices_numeric(api_items, indices: list): @@ -241,7 +242,9 @@ class HighLight: """将dataframe根据API分组,并找到有误差的算子用于高亮""" result = result_df.values api_batches = gen_api_batches(result) - with tqdm(total=len(api_batches), desc="API/Module Analyse Progress", unit="item", ncols=100) as progress_bar: + default_bar_desc = 'API/Module Analyse Progress' + bar_desc_add_rank = f'[{self.rank}]' + default_bar_desc if self.rank else default_bar_desc + with tqdm(total=len(api_batches), desc=bar_desc_add_rank, unit="item", ncols=100) as progress_bar: for api_batch in api_batches: self.find_error_rows(result[api_batch.start: api_batch.params_grad_end_index], api_batch, highlight_dict) @@ -319,7 +322,7 @@ class HighLight: self.update_highlight_err_msg(result_df, highlight_dict) # add highlight err_msg - self.handle_multi_process_malicious_value_check(self.df_malicious_value_check, result_df) + self.df_malicious_value_check(result_df) wb = openpyxl.Workbook() ws = wb.active @@ -370,8 +373,11 @@ class HighLight: pool.close() pool.join() - def df_malicious_value_check(self, df_chunk, result_df_columns): - for row in df_chunk.itertuples(index=False): + def df_malicious_value_check(self, result_df): + result_df_columns = result_df.columns.tolist() + for column in result_df_columns: + self.value_check(column) + for row in result_df.itertuples(index=False): api_name = row[0] for i, value in enumerate(row): self.value_check(value, api_name, i, result_df_columns) diff --git a/debug/accuracy_tools/msprobe/core/compare/utils.py b/debug/accuracy_tools/msprobe/core/compare/utils.py index 5db2241e6..3a34c157e 100644 --- a/debug/accuracy_tools/msprobe/core/compare/utils.py +++ b/debug/accuracy_tools/msprobe/core/compare/utils.py @@ -18,13 +18,14 @@ import re import math import zlib from dataclasses import dataclass +import multiprocessing import numpy as np import pandas as pd from msprobe.core.common.const import Const, CompareConst, FileCheckConst from msprobe.core.common.utils import CompareException, check_regex_prefix_format_valid, logger, safe_get_value -from msprobe.core.common.file_utils import check_file_or_directory_path +from msprobe.core.common.file_utils import check_file_or_directory_path, load_json json_file_mapping = { Const.DUMP_JSON_FILE: "dump.json", @@ -659,41 +660,130 @@ def _compare_parser(parser): help=" The layer mapping file path.", required=False) -def compare_distributed_inner(npu_dump_dir, bench_dump_dir, output_path, compare_func, **kwargs): - if not isinstance(kwargs.get('first_diff_analyze', False), bool): - logger.error('kwargs: first_diff_analyze should be bool, please check!') - raise CompareException(CompareException.INVALID_PARAM_ERROR) - if kwargs.get('suffix'): - logger.error("Argument 'suffix' is not supported for compare_distributed.") - raise CompareException(CompareException.INVALID_PARAM_ERROR) - is_print_compare_log = kwargs.get('is_print_compare_log', True) - # get the ranks and match by order - npu_ranks = sorted(check_and_return_dir_contents(npu_dump_dir, 'rank')) - bench_ranks = sorted(check_and_return_dir_contents(bench_dump_dir, 'rank')) +def get_sorted_ranks(npu_dump_dir, bench_dump_dir): + """ + get the ranks and match by order + """ + unsorted_npu_ranks = check_and_return_dir_contents(npu_dump_dir, 'rank') + unsorted_bench_ranks = check_and_return_dir_contents(bench_dump_dir, 'rank') + # 正则匹配已经校验rank后面必是数字,或者无数字的rank + npu_ranks = sorted(unsorted_npu_ranks, key=lambda x: int(x[4:]) if len(x) > 4 else -1) # 前四个字符都是rank,后面是卡号 + bench_ranks = sorted(unsorted_bench_ranks, key=lambda x: int(x[4:]) if len(x) > 4 else -1) if len(npu_ranks) != len(bench_ranks): logger.error('The number of ranks in the two runs are different. ' 'Unable to match the ranks. Please use another folder to compare ' 'or use compare() api and manually match the ranks.') raise CompareException(CompareException.INVALID_PATH_ERROR) - for nr, br in zip(npu_ranks, bench_ranks): + return npu_ranks, bench_ranks + + +def multi_statistics_compare(func, func_args): + def err_call(args): + logger.error(f'Multiprocess statistics compare failed! Reason: {args}') + try: + pool.close() + except OSError: + logger.error("Pool terminate failed") + + compare_func, input_param_nr_list, output_path, kwargs = func_args + + param_num = len(input_param_nr_list) + process_num = max(int((multiprocessing.cpu_count() + 1) // 4), 1) + if param_num <= process_num: + process_num = param_num + chunks = [[input_param_nr] for input_param_nr in input_param_nr_list] + else: + chunk_size = param_num // process_num + remainder = param_num % process_num + chunks = [input_param_nr_list[i:i + chunk_size] for i in range(0, param_num - remainder, chunk_size)] + for i in range(remainder): + chunks[i].append(input_param_nr_list[param_num - remainder + i]) + + pool = multiprocessing.Pool(process_num) + for chunk in chunks: + pool.apply_async(func, args=(compare_func, chunk, output_path, kwargs), error_callback=err_call) + pool.close() + pool.join() + + +def mp_logger_init(ranks_str): + """ + 多进程比对需要对logger进行wrap和patch,在日志前加上卡号信息,从而实现不同进程日志的隔离 + """ + + def wrap_logger(fn): + def inner(msg, *args, **kwargs): + return fn(ranks_str + msg, *args, **kwargs) + return inner + + logger.info = wrap_logger(logger.info) + logger.warning = wrap_logger(logger.warning) + logger.error = wrap_logger(logger.error) + + +def multi_ranks_compare(compare_func, input_param_nr_list, output_path, kwargs): + """ + 将多卡数据分成多进程后,单进程内可能还有多张卡的数据,因此还需要多次比对 + """ + rank_list = [input_param_nr[1] for input_param_nr in input_param_nr_list] # input_param_nr内部数据结构,2元素tuple + ranks_str = f"[{' '.join(rank_list)}]" + mp_logger_init(ranks_str) + for input_param_nr in input_param_nr_list: + input_param, nr = input_param_nr + compare_entry(compare_func, input_param, output_path, nr, kwargs) + + +def compare_entry(compare_func, input_param, output_path, nr, kwargs): + try: + compare_func(input_param=input_param, output_path=output_path, suffix=f'_{nr}', **kwargs) + except CompareException as e: + if e.code == CompareException.INVALID_DATA_ERROR: + logger.error(f"Invalid or missing 'data' in dump.json. Skipping {nr} comparison.") + if e.code == CompareException.INVALID_TASK_ERROR: + logger.error(f"Invalid or missing 'task' in dump.json. Skipping {nr} comparison.") + + +def compare_distributed_inner(npu_dump_dir, bench_dump_dir, output_path, compare_func, **kwargs): + def extract_compare_param(_file_type): npu_data_dir = os.path.join(npu_dump_dir, nr) bench_data_dir = os.path.join(bench_dump_dir, br) + npu_path = extract_json(npu_data_dir, _file_type) + bench_path = extract_json(bench_data_dir, _file_type) + if npu_path == "" or bench_path == "": + logger.debug(f'Did not find paired {_file_type} in {nr} and {br}, skip comparing.') + return {}, True + _input_param = { + 'npu_json_path': npu_path, + 'bench_json_path': bench_path, + 'is_print_compare_log': kwargs.get('is_print_compare_log', True) + } + return _input_param, False + + if kwargs.get('suffix'): + logger.error("Argument 'suffix' is not supported for compare_distributed.") + raise CompareException(CompareException.INVALID_PARAM_ERROR) + + npu_ranks, bench_ranks = get_sorted_ranks(npu_dump_dir, bench_dump_dir) + + # 统计量、md5比对 + pre_check_dump_path = os.path.join(npu_dump_dir, npu_ranks[0], 'dump.json') if npu_ranks else '' + if not pre_check_dump_path: + return + dump_data = load_json(pre_check_dump_path) + if dump_data.get('task') == Const.STATISTICS: + # dump数据为统计量或md5时,多进程加速比对 + input_param_nr_list = [] + for nr, br in zip(npu_ranks, bench_ranks): + input_param, skip = extract_compare_param(Const.DUMP_JSON_FILE) + if not skip: + input_param_nr_list.append((input_param, nr)) + func_args = (compare_func, input_param_nr_list, output_path, kwargs) + multi_statistics_compare(multi_ranks_compare, func_args) + return + + # 真实数据比对 + for nr, br in zip(npu_ranks, bench_ranks): for file_type in [Const.DUMP_JSON_FILE, Const.DEBUG_JSON_FILE]: - npu_path = extract_json(npu_data_dir, file_type) - bench_path = extract_json(bench_data_dir, file_type) - if npu_path == "" or bench_path == "": - logger.debug(f'Did not find paired {file_type} in {nr} and {br},' - ' skip comparing.') - continue - dump_result_param = { - 'npu_json_path': npu_path, - 'bench_json_path': bench_path, - 'is_print_compare_log': is_print_compare_log - } - try: - compare_func(input_param=dump_result_param, output_path=output_path, suffix=f'_{nr}', **kwargs) - except CompareException as e: - if e.code == CompareException.INVALID_DATA_ERROR: - logger.error(f"Invalid or missing 'data' in dump.json. Skipping {nr} comparison.") - if e.code == CompareException.INVALID_TASK_ERROR: - logger.error(f"Invalid or missing 'task' in dump.json. Skipping {nr} comparison.") + input_param, skip = extract_compare_param(file_type) + if not skip: + compare_entry(compare_func, input_param, output_path, nr, kwargs) diff --git a/debug/accuracy_tools/msprobe/test/core_ut/common/test_utils.py b/debug/accuracy_tools/msprobe/test/core_ut/common/test_utils.py index b7b5ed9bb..551e3dad6 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/common/test_utils.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/common/test_utils.py @@ -38,7 +38,6 @@ from msprobe.core.common.log import logger from msprobe.core.common.exceptions import MsprobeException from msprobe.core.common.utils import (CompareException, check_compare_param, - check_configuration_param, _check_json, check_json_file, check_regex_prefix_format_valid, @@ -134,14 +133,6 @@ class TestUtils(TestCase): self.assertEqual(len(mock__check_json.call_args[0]), 2) self.assertEqual(mock__check_json.call_args[0][1], "stack_path.json") - @patch.object(logger, "error") - def test_check_configuration_param(self, mock_error): - with self.assertRaises(CompareException) as context: - check_configuration_param(stack_mode="False", auto_analyze=True, fuzzy_match=False, - is_print_compare_log=True) - self.assertEqual(context.exception.code, CompareException.INVALID_PARAM_ERROR) - mock_error.assert_called_with("Invalid input parameter, stack_mode which should be only bool type.") - @patch.object(logger, "error") def test__check_json(self, mock_error): class TestOpen: diff --git a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare.py b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare.py index a0d6aeee0..c00d5a061 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare.py @@ -363,7 +363,7 @@ class TestUtilsMethods(unittest.TestCase): } mode_config = ModeConfig(**config_dict) - result = ParseData(mode_config).gen_merge_list(json_data, op_name, stack_json_data) + result = ParseData(mode_config, 'rank0').gen_merge_list(json_data, op_name, stack_json_data) self.assertEqual(result, merge_list) def test_check_op_item_fuzzy(self): @@ -397,7 +397,7 @@ class TestUtilsMethods(unittest.TestCase): from msprobe.pytorch.compare.pt_compare import read_real_data comparator = Comparator(read_real_data, mode_config, mapping_config) - result = comparator.compare_statistics(file_list) + result = comparator.compare_statistics(file_list, 'rank0') o_data = [ ['Functional.linear.0.forward.input.0', 'Functional.linear.0.forward.input.0', 'torch.float32', 'torch.float32', '[2, 2]', '[2, 2]', 'False', 'False', @@ -430,7 +430,7 @@ class TestParseData(unittest.TestCase): stack_mode = True mode_config = ModeConfig(stack_mode=stack_mode) - parse_data = ParseData(mode_config) + parse_data = ParseData(mode_config, 'rank0') npu_df, bench_df = parse_data.parse(file_list) target_df = pd.DataFrame( @@ -449,8 +449,8 @@ class TestParseData(unittest.TestCase): stack_mode = True mode_config = ModeConfig(stack_mode=stack_mode) - parse_data = ParseData(mode_config) - npu_df = parse_data.gen_data_df(npu_json_data, stack_json_data) + parse_data = ParseData(mode_config, 'rank0') + npu_df = parse_data.gen_data_df(npu_json_data, stack_json_data, 'NPU') target_df = pd.DataFrame( [['Functional.linear.0.forward.input.0', 'torch.float32', @@ -467,8 +467,8 @@ class TestParseData(unittest.TestCase): stack_mode = True mode_config = ModeConfig(stack_mode=stack_mode, dump_mode=Const.ALL) - parse_data = ParseData(mode_config) - npu_df = parse_data.gen_data_df(npu_json_data, stack_json_data) + parse_data = ParseData(mode_config, 'rank0') + npu_df = parse_data.gen_data_df(npu_json_data, stack_json_data, 'NPU') target_df = pd.DataFrame( [['Functional.linear.0.forward.input.0', 'torch.float32', [2, 2], @@ -485,8 +485,8 @@ class TestParseData(unittest.TestCase): stack_mode = True mode_config = ModeConfig(stack_mode=stack_mode, dump_mode=Const.MD5) - parse_data = ParseData(mode_config) - npu_df = parse_data.gen_data_df(npu_json_data, stack_json_data) + parse_data = ParseData(mode_config, 'rank0') + npu_df = parse_data.gen_data_df(npu_json_data, stack_json_data, 'NPU') target_df = pd.DataFrame( [['Functional.linear.0.forward.input.0', 'torch.float32', [2, 2], @@ -503,7 +503,7 @@ class TestParseData(unittest.TestCase): stack_mode = True mode_config = ModeConfig(stack_mode=stack_mode) - parse_data = ParseData(mode_config) + parse_data = ParseData(mode_config, 'rank0') merge_list = parse_data.gen_merge_list(npu_json_data, 'Functional.linear.0.forward', stack_json_data) target_dict = { diff --git a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare_check.py b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare_check.py index aefcd0f34..0d200f94a 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare_check.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare_check.py @@ -1,9 +1,12 @@ # coding=utf-8 import unittest +from unittest.mock import patch from msprobe.core.compare.check import check_dump_json_str, check_json_key_value, valid_key_value, \ - check_stack_json_str + check_stack_json_str, check_configuration_param from msprobe.core.common.utils import CompareException +from msprobe.core.common.log import logger +from msprobe.core.compare.acc_compare import ComparisonConfig # test_check_struct_match @@ -124,3 +127,25 @@ class TestUtilsMethods(unittest.TestCase): with self.assertRaises(CompareException) as context: check_stack_json_str(stack_info, op_name) self.assertEqual(context.exception.code, CompareException.INVALID_CHAR_ERROR) + + @patch.object(logger, "error") + def test_check_configuration_param(self, mock_error): + config = ComparisonConfig( + dump_mode='', + stack_mode='False', + auto_analyze=True, + fuzzy_match=False, + highlight=False, + data_mapping={}, + suffix='', + cell_mapping={}, + api_mapping={}, + layer_mapping={}, + first_diff_analyze=False, + compared_file_type='', + is_print_compare_log=True + ) + with self.assertRaises(CompareException) as context: + check_configuration_param(config) + self.assertEqual(context.exception.code, CompareException.INVALID_PARAM_ERROR) + mock_error.assert_called_with("Invalid input parameter, stack_mode which should be only bool type.") \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_cmp_highlight.py b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_cmp_highlight.py index aeb7308d9..a3809974d 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_cmp_highlight.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_cmp_highlight.py @@ -245,7 +245,7 @@ class TestUtilsMethods(unittest.TestCase): highlight_dict = {"red_lines": [], "red_rows": set(), "yellow_lines": [], "yellow_rows": set()} mode_config = ModeConfig(dump_mode=Const.ALL) - highlight = HighLight(mode_config) + highlight = HighLight(mode_config, '') highlight.find_error_rows(compare_result, api_batch, highlight_dict) self.assertEqual(highlight_dict, {"red_lines": [], "red_rows": set(), "yellow_lines": [], "yellow_rows": set()}) @@ -259,7 +259,7 @@ class TestUtilsMethods(unittest.TestCase): highlight_dict = {} mode_config = ModeConfig(dump_mode=Const.MD5) - highlight = HighLight(mode_config) + highlight = HighLight(mode_config, '') result = highlight.find_error_rows(compare_result, api_batch, highlight_dict) self.assertEqual(result, None) @@ -272,7 +272,7 @@ class TestUtilsMethods(unittest.TestCase): result_df_columns = CompareConst.COMPARE_RESULT_HEADER mode_config = ModeConfig() - highlight = HighLight(mode_config) + highlight = HighLight(mode_config, '') highlight.value_check(value, api_name, i, result_df_columns) mock_logger.error.assert_called_once_with( @@ -289,13 +289,13 @@ class TestUtilsMethods(unittest.TestCase): result_df = pd.DataFrame(data, columns=columns) mode_config = ModeConfig(dump_mode=Const.ALL) - highlight = HighLight(mode_config) - highlight.df_malicious_value_check(result_df, columns) + highlight = HighLight(mode_config, '') + highlight.df_malicious_value_check(result_df) def test_compare_result_df_convert(self): value = float("nan") mode_config = ModeConfig() - highlight = HighLight(mode_config) + highlight = HighLight(mode_config, '') result = highlight.compare_result_df_convert(value) self.assertEqual(result, "nan\t") @@ -310,7 +310,7 @@ class TestUtilsMethods(unittest.TestCase): file_path = os.path.join(base_dir, 'result.xlsx') mode_config = ModeConfig(dump_mode=Const.ALL) - highlight = HighLight(mode_config) + highlight = HighLight(mode_config, '') highlight.highlight_rows_xlsx(result_df, highlight_dict, file_path) generate_result_xlsx(base_dir) @@ -327,7 +327,7 @@ class TestUtilsMethods(unittest.TestCase): file_path = os.path.join(base_dir, 'result.xlsx') mode_config = ModeConfig(dump_mode=Const.ALL) - highlight = HighLight(mode_config) + highlight = HighLight(mode_config, '') highlight.highlight_rows_xlsx(result_df, highlight_dict, file_path) generate_result_xlsx(base_dir) @@ -348,7 +348,7 @@ class TestUtilsMethods(unittest.TestCase): sys.stdout = open(temp_output_file, 'w') mode_config = ModeConfig(dump_mode=Const.ALL) - highlight = HighLight(mode_config) + highlight = HighLight(mode_config, '') highlight.highlight_rows_xlsx(result_df, highlight_dict, file_path) with open(temp_output_file, 'r') as f: @@ -375,7 +375,7 @@ class TestUtilsMethods(unittest.TestCase): sys.stdout = open(temp_output_file, 'w') mode_config = ModeConfig(dump_mode=Const.ALL) - highlight = HighLight(mode_config) + highlight = HighLight(mode_config, '') highlight.highlight_rows_xlsx(result_df, highlight_dict, file_path) with open(temp_output_file, 'r') as f: @@ -409,7 +409,7 @@ class TestUtilsMethods(unittest.TestCase): } mode_config = ModeConfig(dump_mode=Const.ALL) - highlight = HighLight(mode_config) + highlight = HighLight(mode_config, '') highlight.update_highlight_err_msg(result_df, highlight_dict) t_data = [['Functional.linear.0.forward.input.0', 'Functional.linear.0.forward.input.0', @@ -431,7 +431,7 @@ class TestUtilsMethods(unittest.TestCase): highlight_dict = {} mode_config = ModeConfig(dump_mode=Const.MD5) - highlight = HighLight(mode_config) + highlight = HighLight(mode_config, '') result = highlight.update_highlight_err_msg(result_df, highlight_dict) self.assertEqual(result, None) @@ -450,7 +450,7 @@ class TestUtilsMethods(unittest.TestCase): 'yellow_lines': [(0, ['c']), (1, ['d'])] } mode_config = ModeConfig() - highlight = HighLight(mode_config) + highlight = HighLight(mode_config, '') result = highlight.update_highlight_err_msg(result_df, highlight_dict) self.assertEqual(result, None) @@ -462,7 +462,7 @@ class TestUtilsMethods(unittest.TestCase): summary_result = [summary_line_input, summary_line_1, summary_line_2, summary_line_3] highlight_dict_test = {"red_rows": set(), "yellow_rows": set(), "red_lines": [], "yellow_lines": []} mode_config = ModeConfig() - highlight = HighLight(mode_config) + highlight = HighLight(mode_config, '') highlight.find_error_rows(summary_result, api_batch, highlight_dict_test) self.assertEqual(highlight_dict_test, {"red_rows": set(), "yellow_rows": set(), "red_lines": [], "yellow_lines": []}) @@ -472,7 +472,7 @@ class TestUtilsMethods(unittest.TestCase): result_df = pd.DataFrame(result) highlight_dict_test = {"red_rows": set(), "yellow_rows": set(), "red_lines": [], "yellow_lines": []} mode_config = ModeConfig(dump_mode=Const.ALL) - highlight = HighLight(mode_config) + highlight = HighLight(mode_config, '') highlight.find_compare_result_error_rows(result_df, highlight_dict_test) self.assertEqual(highlight_dict_test, { "red_rows": {1, 3}, -- Gitee