diff --git a/debug/accuracy_tools/msprobe/core/common/const.py b/debug/accuracy_tools/msprobe/core/common/const.py index a86b87ce0d064bdd1019b551968ad48d9381512c..69b0fb9cd7e40a24f40bf0b24f44c7f58e97f8b9 100644 --- a/debug/accuracy_tools/msprobe/core/common/const.py +++ b/debug/accuracy_tools/msprobe/core/common/const.py @@ -440,7 +440,7 @@ class CompareConst: SUMMARY = "summary" COMPARE_RESULT = "compare_result" COMPARE_MESSAGE = "compare_message" - MAX_EXCEL_LENGTH = 1048576 + MAX_EXCEL_LENGTH = 1048500 YES = "Yes" NO = "No" STATISTICS_INDICATOR_NUM = 4 diff --git a/debug/accuracy_tools/msprobe/core/common/utils.py b/debug/accuracy_tools/msprobe/core/common/utils.py index 5645cba6919afb36c56805b6dcce614bf7969ff2..197c1e6b15ce0589681af1c469b7aafa5593440e 100644 --- a/debug/accuracy_tools/msprobe/core/common/utils.py +++ b/debug/accuracy_tools/msprobe/core/common/utils.py @@ -231,9 +231,10 @@ def check_compare_param(input_param, output_path, dump_mode, stack_mode): _check_json(stack_json, input_param.get("stack_json_path")) -def check_configuration_param(stack_mode=False, auto_analyze=True, fuzzy_match=False, is_print_compare_log=True): - arg_list = [stack_mode, auto_analyze, fuzzy_match, is_print_compare_log] - arg_names = ['stack_mode', 'auto_analyze', 'fuzzy_match', 'is_print_compare_log'] +def check_configuration_param(stack_mode=False, auto_analyze=True, fuzzy_match=False, highlight=True, + is_print_compare_log=True): + arg_list = [stack_mode, auto_analyze, fuzzy_match, highlight, is_print_compare_log] + arg_names = ['stack_mode', 'auto_analyze', 'fuzzy_match', 'highlight', 'is_print_compare_log'] for arg, name in zip(arg_list, arg_names): if not isinstance(arg, bool): logger.error(f"Invalid input parameter, {name} which should be only bool type.") @@ -700,4 +701,3 @@ def check_process_num(process_num): raise ValueError(f"process_num({process_num}) is not a positive integer") if process_num > Const.MAX_PROCESS_NUM: raise ValueError(f"The maximum supported process_num is {Const.MAX_PROCESS_NUM}, current value: {process_num}.") - diff --git a/debug/accuracy_tools/msprobe/core/compare/acc_compare.py b/debug/accuracy_tools/msprobe/core/compare/acc_compare.py index 7b34b73d2e054ed9c73095b21997e407a81b7232..1100cbe2136847e4db36bccafd065142714f96f3 100644 --- a/debug/accuracy_tools/msprobe/core/compare/acc_compare.py +++ b/debug/accuracy_tools/msprobe/core/compare/acc_compare.py @@ -25,7 +25,7 @@ from tqdm import tqdm from msprobe.core.advisor.advisor import Advisor from msprobe.core.common.const import CompareConst, Const from msprobe.core.common.exceptions import FileCheckException -from msprobe.core.common.file_utils import load_json, remove_path, create_directory +from msprobe.core.common.file_utils import load_json, remove_path, create_directory, save_excel from msprobe.core.common.log import logger from msprobe.core.common.utils import CompareException, add_time_with_xlsx, check_op_str_pattern_valid, \ set_dump_path, get_dump_mode, check_compare_param, check_configuration_param, load_stack_json, get_file_type @@ -43,6 +43,7 @@ class ComparisonConfig: stack_mode: bool auto_analyze: bool fuzzy_match: bool + highlight: bool data_mapping: dict suffix: str cell_mapping: dict @@ -113,13 +114,20 @@ class Comparator: compare_real_data = CompareRealData(self.file_reader, self.mode_config, self.cross_frame) result_df = compare_real_data.do_multi_process(input_param, result_df) - # highlight suspicious API - highlight_dict = {"red_rows": set(), "yellow_rows": set(), "red_lines": [], "yellow_lines": []} - highlight = HighLight(self.mode_config) - if self.mode_config.compared_file_type == Const.DUMP_JSON_FILE: - highlight.find_compare_result_error_rows(result_df, highlight_dict) - result_df.drop(columns=['state', 'api_origin_name'], inplace=True) # 删除中间数据,两列不落盘 - highlight.highlight_rows_xlsx(result_df, highlight_dict, file_path) + # save result excel file + logger.info(f'Saving result excel file in progress. The file path is: {file_path}.') + if self.mode_config.highlight and len(result_df) <= CompareConst.MAX_EXCEL_LENGTH: + # highlight if not too long + highlight_dict = {"red_rows": set(), "yellow_rows": set(), "red_lines": [], "yellow_lines": []} + highlight = HighLight(self.mode_config) + if self.mode_config.compared_file_type == Const.DUMP_JSON_FILE: + highlight.find_compare_result_error_rows(result_df, highlight_dict) + result_df.drop(columns=['state', 'api_origin_name'], inplace=True) # 删除中间数据,两列不落盘 + highlight.highlight_rows_xlsx(result_df, highlight_dict, file_path) + else: + # fallback to simple save without highlight + result_df.drop(columns=['state', 'api_origin_name'], inplace=True) # 删除中间数据,两列不落盘 + save_excel(file_path, result_df) # output compare analysis suggestions if self.mode_config.auto_analyze: @@ -729,6 +737,7 @@ def setup_comparison(input_param, output_path, **kwargs) -> ComparisonConfig: stack_mode=False, auto_analyze=kwargs.get('auto_analyze', True), fuzzy_match=kwargs.get('fuzzy_match', False), + highlight=kwargs.get('highlight', False), data_mapping=kwargs.get('data_mapping', {}), suffix=kwargs.get('suffix', ''), cell_mapping=kwargs.get('cell_mapping', {}), @@ -747,7 +756,7 @@ def setup_comparison(input_param, output_path, **kwargs) -> ComparisonConfig: else: config.stack_mode = set_stack_json_path(input_param) - check_configuration_param(config.stack_mode, config.auto_analyze, config.fuzzy_match, + check_configuration_param(config.stack_mode, config.auto_analyze, config.fuzzy_match, config.highlight, input_param.get('is_print_compare_log', True)) create_directory(output_path) check_compare_param(input_param, output_path, config.dump_mode, config.stack_mode) diff --git a/debug/accuracy_tools/msprobe/core/compare/compare_cli.py b/debug/accuracy_tools/msprobe/core/compare/compare_cli.py index 806540df54eb77b007671187450ac91ef989145b..59cdb3f6f83a009ce94df71a3a9935a5bc4bb049 100644 --- a/debug/accuracy_tools/msprobe/core/compare/compare_cli.py +++ b/debug/accuracy_tools/msprobe/core/compare/compare_cli.py @@ -32,6 +32,7 @@ def compare_cli(args): raise CompareException(CompareException.INVALID_PATH_ERROR) frame_name = args.framework auto_analyze = not args.compare_only + if frame_name == Const.PT_FRAMEWORK: from msprobe.pytorch.compare.pt_compare import compare from msprobe.pytorch.compare.distributed_compare import compare_distributed @@ -43,6 +44,7 @@ def compare_cli(args): common_kwargs = { "auto_analyze": auto_analyze, "fuzzy_match": args.fuzzy_match, + "highlight": args.highlight, "data_mapping": args.data_mapping, } diff --git a/debug/accuracy_tools/msprobe/core/compare/config.py b/debug/accuracy_tools/msprobe/core/compare/config.py index 448139b8b3cf545cac53a573594f7b105ddb0c41..1c127bb7cd96fd5685e79dde36167633f70caf35 100644 --- a/debug/accuracy_tools/msprobe/core/compare/config.py +++ b/debug/accuracy_tools/msprobe/core/compare/config.py @@ -20,13 +20,13 @@ from msprobe.core.common.file_utils import load_yaml class ModeConfig: - def __init__(self, stack_mode=False, auto_analyze=True, fuzzy_match=False, dump_mode=Const.SUMMARY, - compared_file_type=Const.DUMP_JSON_FILE): - self.stack_mode = stack_mode - self.auto_analyze = auto_analyze - self.fuzzy_match = fuzzy_match - self.dump_mode = dump_mode - self.compared_file_type = compared_file_type + def __init__(self, **kwargs): + self.stack_mode = kwargs.get('stack_mode', False) + self.auto_analyze = kwargs.get('auto_analyze', True) + self.fuzzy_match = kwargs.get('fuzzy_match', False) + self.highlight = kwargs.get('highlight', True) + self.dump_mode = kwargs.get('dump_mode', Const.SUMMARY) + self.compared_file_type = kwargs.get('compared_file_type', Const.DUMP_JSON_FILE) class MappingConfig: diff --git a/debug/accuracy_tools/msprobe/core/compare/highlight.py b/debug/accuracy_tools/msprobe/core/compare/highlight.py index 37eec421690d4cb0f4591e59867885e3b877e29f..9e5fe0559c933eafac923f00f0c8b9c4f18fbf96 100644 --- a/debug/accuracy_tools/msprobe/core/compare/highlight.py +++ b/debug/accuracy_tools/msprobe/core/compare/highlight.py @@ -349,28 +349,19 @@ class HighLight: self.update_highlight_err_msg(result_df, highlight_dict) # add highlight err_msg - wb = openpyxl.Workbook() - ws = wb.active - - # write header - logger.info('Initializing Excel file.') - self.handle_multi_process_malicious_value_check(self.df_malicious_value_check, result_df) + wb = openpyxl.Workbook() + ws = wb.active result_df_convert = result_df.applymap(self.compare_result_df_convert) - for row in dataframe_to_rows(result_df_convert, index=False, header=True): ws.append(row) # 对可疑数据标色 logger.info('Coloring Excel in progress.') + red_fill = PatternFill(start_color=CompareConst.RED, end_color=CompareConst.RED, fill_type="solid") + yellow_fill = PatternFill(start_color=CompareConst.YELLOW, end_color=CompareConst.YELLOW, fill_type="solid") col_len = len(result_df.columns) - red_fill = PatternFill( - start_color=CompareConst.RED, end_color=CompareConst.RED, fill_type="solid" - ) - yellow_fill = PatternFill( - start_color=CompareConst.YELLOW, end_color=CompareConst.YELLOW, fill_type="solid", - ) for i in highlight_dict.get("red_rows", []): for j in range(1, col_len + 1): ws.cell(row=i + 2, column=j).fill = red_fill # 2因为ws.cell中的row或column需要>=1,数据从第2行开始 @@ -378,7 +369,6 @@ class HighLight: for j in range(1, col_len + 1): ws.cell(row=i + 2, column=j).fill = yellow_fill - logger.info('Saving Excel file to disk: %s' % file_path) save_workbook(wb, file_path) def handle_multi_process_malicious_value_check(self, func, result_df): diff --git a/debug/accuracy_tools/msprobe/core/compare/utils.py b/debug/accuracy_tools/msprobe/core/compare/utils.py index 0f5ea4a50d0501e0f1ff142f2a34c88648acaa71..1f028fac6aa71b2850118cce7c422ee405afa52b 100644 --- a/debug/accuracy_tools/msprobe/core/compare/utils.py +++ b/debug/accuracy_tools/msprobe/core/compare/utils.py @@ -158,7 +158,7 @@ def is_leaf_data(op_data): def gen_op_item(op_data, op_name, state): op_item = {} - op_item.update(op_data) + op_item.update({key: str(value) if isinstance(value, bool) else value for key, value in op_data.items()}) data_name = op_data.get(Const.DATA_NAME) if op_data.get(Const.DATA_NAME) else '-1' # 如果是""也返回-1 op_item[Const.DATA_NAME] = data_name op_item['full_op_name'] = data_name.rsplit(Const.SEP, 1)[0] if data_name != '-1' else op_name @@ -543,6 +543,8 @@ def _compare_parser(parser): help=" Whether to give advisor.", required=False) parser.add_argument("-f", "--fuzzy_match", dest="fuzzy_match", action="store_true", help=" Whether to perform a fuzzy match on the api name.", required=False) + parser.add_argument("-hl", "--highlight", dest="highlight", action="store_true", + help=" Whether to set result highlighting.", required=False) parser.add_argument("-cm", "--cell_mapping", dest="cell_mapping", type=str, nargs='?', const=True, help=" The cell mapping file path.", required=False) parser.add_argument("-am", "--api_mapping", dest="api_mapping", type=str, nargs='?', const=True, diff --git a/debug/accuracy_tools/msprobe/mindspore/compare/ms_compare.py b/debug/accuracy_tools/msprobe/mindspore/compare/ms_compare.py index 42d973a0e896dc5ee700e17f435275969eee1025..ae3dfa63d78b2b7e4553a4f68df90aa84dc362ea 100644 --- a/debug/accuracy_tools/msprobe/mindspore/compare/ms_compare.py +++ b/debug/accuracy_tools/msprobe/mindspore/compare/ms_compare.py @@ -35,8 +35,16 @@ def ms_compare(input_param, output_path, **kwargs): config.data_mapping = generate_data_mapping_by_layer_mapping(input_param, config.layer_mapping, output_path) is_cross_framework = check_cross_framework(input_param.get('bench_json_path')) - mode_config = ModeConfig(config.stack_mode, config.auto_analyze, config.fuzzy_match, - config.dump_mode, config.compared_file_type) + + config_dict = { + 'stack_mode': config.stack_mode, + 'auto_analyze': config.auto_analyze, + 'fuzzy_match': config.fuzzy_match, + 'highlight': config.highlight, + 'dump_mode': config.dump_mode, + 'compared_file_type': config.compared_file_type + } + mode_config = ModeConfig(**config_dict) mapping_config = MappingConfig(config.cell_mapping, config.api_mapping, config.data_mapping) ms_comparator = Comparator(read_real_data, mode_config, mapping_config, is_cross_framework) ms_comparator.compare_core(input_param, output_path, suffix=config.suffix) diff --git a/debug/accuracy_tools/msprobe/pytorch/compare/pt_compare.py b/debug/accuracy_tools/msprobe/pytorch/compare/pt_compare.py index 96e9fc88e8aa3457b44b2011732738e0d4689887..0f9c9f26a94dbf4ddda801f6d3d95b87bccb23bc 100644 --- a/debug/accuracy_tools/msprobe/pytorch/compare/pt_compare.py +++ b/debug/accuracy_tools/msprobe/pytorch/compare/pt_compare.py @@ -26,8 +26,15 @@ def read_real_data(npu_dir, npu_data_name, bench_dir, bench_data_name, _) -> tup def compare(input_param, output_path, **kwargs): config = setup_comparison(input_param, output_path, **kwargs) - mode_config = ModeConfig(config.stack_mode, config.auto_analyze, config.fuzzy_match, - config.dump_mode, config.compared_file_type) + config_dict = { + 'stack_mode': config.stack_mode, + 'auto_analyze': config.auto_analyze, + 'fuzzy_match': config.fuzzy_match, + 'highlight': config.highlight, + 'dump_mode': config.dump_mode, + 'compared_file_type': config.compared_file_type + } + mode_config = ModeConfig(**config_dict) mapping_config = MappingConfig(data_mapping=config.data_mapping) pt_comparator = Comparator(read_real_data, mode_config, mapping_config) pt_comparator.compare_core(input_param, output_path, suffix=config.suffix) diff --git a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare.py b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare.py index 639be036067d3fa0d1ecf93f470e48c1c4f1c6bf..1f7c515a59b5c13edcae4e890e7054d55984280f 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare.py @@ -354,22 +354,25 @@ class TestUtilsMethods(unittest.TestCase): 'state': ['input'] } - stack_mode = True - auto_analyze = True - fuzzy_match = False - dump_mode = Const.SUMMARY - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) + config_dict = { + 'stack_mode': True, + 'auto_analyze': True, + 'fuzzy_match': False, + 'dump_mode': Const.SUMMARY, + } + mode_config = ModeConfig(**config_dict) result = ParseData(mode_config).gen_merge_list(json_data, op_name, stack_json_data) self.assertEqual(result, merge_list) def test_check_op_item_fuzzy(self): - stack_mode = False - auto_analyze = True - dump_mode = Const.SUMMARY - - fuzzy_match = True - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) + config_dict = { + 'stack_mode': False, + 'auto_analyze': True, + 'fuzzy_match': True, + 'dump_mode': Const.SUMMARY, + } + mode_config = ModeConfig(**config_dict) mapping_config = MappingConfig() match = Match(mode_config, mapping_config, cross_frame=False) @@ -382,11 +385,13 @@ class TestUtilsMethods(unittest.TestCase): file_list = [os.path.join(base_dir, 'dump.json'), os.path.join(base_dir, 'dump.json'), os.path.join(base_dir, 'stack.json')] - stack_mode = True - auto_analyze = True - fuzzy_match = False - dump_mode = Const.SUMMARY - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) + config_dict = { + 'stack_mode': True, + 'auto_analyze': True, + 'fuzzy_match': False, + 'dump_mode': Const.SUMMARY, + } + mode_config = ModeConfig(**config_dict) mapping_config = MappingConfig() from msprobe.pytorch.compare.pt_compare import read_real_data @@ -763,11 +768,13 @@ class TestMatch(unittest.TestCase): self.assertTrue(match_result.equals(expected)) def test_match_op_both_last_element(self): - stack_mode = False - auto_analyze = True - fuzzy_match = False - dump_mode = Const.SUMMARY - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) + config_dict = { + 'stack_mode': False, + 'auto_analyze': True, + 'fuzzy_match': False, + 'dump_mode': Const.SUMMARY, + } + mode_config = ModeConfig(**config_dict) mapping_config = MappingConfig() match = Match(mode_config, mapping_config, cross_frame=False) @@ -776,11 +783,13 @@ class TestMatch(unittest.TestCase): self.assertEqual(b, 0) def test_match_op_only_npu_last_element(self): - stack_mode = False - auto_analyze = True - fuzzy_match = False - dump_mode = Const.SUMMARY - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) + config_dict = { + 'stack_mode': False, + 'auto_analyze': True, + 'fuzzy_match': False, + 'dump_mode': Const.SUMMARY, + } + mode_config = ModeConfig(**config_dict) mapping_config = MappingConfig() match = Match(mode_config, mapping_config, cross_frame=False) @@ -789,11 +798,13 @@ class TestMatch(unittest.TestCase): self.assertEqual(b, 0) def test_match_op_only_bench_last_element(self): - stack_mode = False - auto_analyze = True - fuzzy_match = False - dump_mode = Const.SUMMARY - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) + config_dict = { + 'stack_mode': False, + 'auto_analyze': True, + 'fuzzy_match': False, + 'dump_mode': Const.SUMMARY, + } + mode_config = ModeConfig(**config_dict) mapping_config = MappingConfig() match = Match(mode_config, mapping_config, cross_frame=False) diff --git a/debug/accuracy_tools/msprobe/visualization/builder/msprobe_adapter.py b/debug/accuracy_tools/msprobe/visualization/builder/msprobe_adapter.py index 2b7f7886535068824e782c8cfab1b6aa283198e5..cc304c8aa7a12e0335b7990e068f4679f8e35d92 100644 --- a/debug/accuracy_tools/msprobe/visualization/builder/msprobe_adapter.py +++ b/debug/accuracy_tools/msprobe/visualization/builder/msprobe_adapter.py @@ -54,7 +54,13 @@ def run_real_data(dump_path_param, csv_path, framework, is_cross_frame=False): framework: 框架类型, pytorch或mindspore is_cross_frame: 是否进行跨框架比对,仅支持mindspore比pytorch, 其中pytorch为标杆 """ - mode_config = ModeConfig(stack_mode=False, auto_analyze=True, fuzzy_match=False, dump_mode=Const.ALL) + config_dict = { + 'stack_mode': False, + 'auto_analyze': True, + 'fuzzy_match': False, + 'dump_mode': Const.ALL + } + mode_config = ModeConfig(**config_dict) if framework == Const.PT_FRAMEWORK: from msprobe.pytorch.compare.pt_compare import read_real_data