diff --git a/debug/accuracy_tools/msprobe/core/common/const.py b/debug/accuracy_tools/msprobe/core/common/const.py index 7b90110f29ef09ad88b0fa748f2e718f8f32c1b0..7d694c3b6bab7052869c216a5d3517e8c794bfad 100644 --- a/debug/accuracy_tools/msprobe/core/common/const.py +++ b/debug/accuracy_tools/msprobe/core/common/const.py @@ -438,7 +438,7 @@ class CompareConst: SUMMARY = "summary" COMPARE_RESULT = "compare_result" COMPARE_MESSAGE = "compare_message" - MAX_EXCEL_LENGTH = 1048576 + MAX_EXCEL_LENGTH = 1048500 YES = "Yes" NO = "No" STATISTICS_INDICATOR_NUM = 4 diff --git a/debug/accuracy_tools/msprobe/core/common/utils.py b/debug/accuracy_tools/msprobe/core/common/utils.py index 5645cba6919afb36c56805b6dcce614bf7969ff2..197c1e6b15ce0589681af1c469b7aafa5593440e 100644 --- a/debug/accuracy_tools/msprobe/core/common/utils.py +++ b/debug/accuracy_tools/msprobe/core/common/utils.py @@ -231,9 +231,10 @@ def check_compare_param(input_param, output_path, dump_mode, stack_mode): _check_json(stack_json, input_param.get("stack_json_path")) -def check_configuration_param(stack_mode=False, auto_analyze=True, fuzzy_match=False, is_print_compare_log=True): - arg_list = [stack_mode, auto_analyze, fuzzy_match, is_print_compare_log] - arg_names = ['stack_mode', 'auto_analyze', 'fuzzy_match', 'is_print_compare_log'] +def check_configuration_param(stack_mode=False, auto_analyze=True, fuzzy_match=False, highlight=True, + is_print_compare_log=True): + arg_list = [stack_mode, auto_analyze, fuzzy_match, highlight, is_print_compare_log] + arg_names = ['stack_mode', 'auto_analyze', 'fuzzy_match', 'highlight', 'is_print_compare_log'] for arg, name in zip(arg_list, arg_names): if not isinstance(arg, bool): logger.error(f"Invalid input parameter, {name} which should be only bool type.") @@ -700,4 +701,3 @@ def check_process_num(process_num): raise ValueError(f"process_num({process_num}) is not a positive integer") if process_num > Const.MAX_PROCESS_NUM: raise ValueError(f"The maximum supported process_num is {Const.MAX_PROCESS_NUM}, current value: {process_num}.") - diff --git a/debug/accuracy_tools/msprobe/core/compare/acc_compare.py b/debug/accuracy_tools/msprobe/core/compare/acc_compare.py index 9fd13743aba451627b681164fcd4d0bc749fb28d..04c432fee144a0bda5b6f2842e51108598718ec8 100644 --- a/debug/accuracy_tools/msprobe/core/compare/acc_compare.py +++ b/debug/accuracy_tools/msprobe/core/compare/acc_compare.py @@ -25,7 +25,7 @@ from tqdm import tqdm from msprobe.core.advisor.advisor import Advisor from msprobe.core.common.const import CompareConst, Const from msprobe.core.common.exceptions import FileCheckException -from msprobe.core.common.file_utils import load_json, remove_path, create_directory +from msprobe.core.common.file_utils import load_json, remove_path, create_directory, save_excel from msprobe.core.common.log import logger from msprobe.core.common.utils import CompareException, add_time_with_xlsx, check_op_str_pattern_valid, \ set_dump_path, get_dump_mode, check_compare_param, check_configuration_param, load_stack_json, get_file_type @@ -43,6 +43,7 @@ class ComparisonConfig: stack_mode: bool auto_analyze: bool fuzzy_match: bool + highlight: bool data_mapping: dict suffix: str cell_mapping: dict @@ -113,12 +114,20 @@ class Comparator: compare_real_data = CompareRealData(self.file_reader, self.mode_config, self.cross_frame) result_df = compare_real_data.do_multi_process(input_param, result_df) - # highlight suspicious API - highlight_dict = {"red_rows": set(), "yellow_rows": set(), "red_lines": [], "yellow_lines": []} - highlight = HighLight(self.mode_config) - if self.mode_config.compared_file_type == Const.DUMP_JSON_FILE: - highlight.find_compare_result_error_rows(result_df, highlight_dict) - highlight.highlight_rows_xlsx(result_df, highlight_dict, file_path) + # save result excel file + logger.info(f'Saving result excel file in progress. The file path is: {file_path}.') + if self.mode_config.highlight: + if len(result_df) > CompareConst.MAX_EXCEL_LENGTH: + save_excel(file_path, result_df) + else: + # highlight suspicious API + highlight_dict = {"red_rows": set(), "yellow_rows": set(), "red_lines": [], "yellow_lines": []} + highlight = HighLight(self.mode_config) + if self.mode_config.compared_file_type == Const.DUMP_JSON_FILE: + highlight.find_compare_result_error_rows(result_df, highlight_dict) + highlight.highlight_rows_xlsx(result_df, highlight_dict, file_path) + else: + save_excel(file_path, result_df) # output compare analysis suggestions if self.mode_config.auto_analyze: @@ -718,6 +727,7 @@ def setup_comparison(input_param, output_path, **kwargs) -> ComparisonConfig: stack_mode=False, auto_analyze=kwargs.get('auto_analyze', True), fuzzy_match=kwargs.get('fuzzy_match', False), + highlight=kwargs.get('highlight', True), data_mapping=kwargs.get('data_mapping', {}), suffix=kwargs.get('suffix', ''), cell_mapping=kwargs.get('cell_mapping', {}), @@ -736,7 +746,7 @@ def setup_comparison(input_param, output_path, **kwargs) -> ComparisonConfig: else: config.stack_mode = set_stack_json_path(input_param) - check_configuration_param(config.stack_mode, config.auto_analyze, config.fuzzy_match, + check_configuration_param(config.stack_mode, config.auto_analyze, config.fuzzy_match, config.highlight, input_param.get('is_print_compare_log', True)) create_directory(output_path) check_compare_param(input_param, output_path, config.dump_mode, config.stack_mode) diff --git a/debug/accuracy_tools/msprobe/core/compare/compare_cli.py b/debug/accuracy_tools/msprobe/core/compare/compare_cli.py index 806540df54eb77b007671187450ac91ef989145b..59cdb3f6f83a009ce94df71a3a9935a5bc4bb049 100644 --- a/debug/accuracy_tools/msprobe/core/compare/compare_cli.py +++ b/debug/accuracy_tools/msprobe/core/compare/compare_cli.py @@ -32,6 +32,7 @@ def compare_cli(args): raise CompareException(CompareException.INVALID_PATH_ERROR) frame_name = args.framework auto_analyze = not args.compare_only + if frame_name == Const.PT_FRAMEWORK: from msprobe.pytorch.compare.pt_compare import compare from msprobe.pytorch.compare.distributed_compare import compare_distributed @@ -43,6 +44,7 @@ def compare_cli(args): common_kwargs = { "auto_analyze": auto_analyze, "fuzzy_match": args.fuzzy_match, + "highlight": args.highlight, "data_mapping": args.data_mapping, } diff --git a/debug/accuracy_tools/msprobe/core/compare/config.py b/debug/accuracy_tools/msprobe/core/compare/config.py index 448139b8b3cf545cac53a573594f7b105ddb0c41..1c127bb7cd96fd5685e79dde36167633f70caf35 100644 --- a/debug/accuracy_tools/msprobe/core/compare/config.py +++ b/debug/accuracy_tools/msprobe/core/compare/config.py @@ -20,13 +20,13 @@ from msprobe.core.common.file_utils import load_yaml class ModeConfig: - def __init__(self, stack_mode=False, auto_analyze=True, fuzzy_match=False, dump_mode=Const.SUMMARY, - compared_file_type=Const.DUMP_JSON_FILE): - self.stack_mode = stack_mode - self.auto_analyze = auto_analyze - self.fuzzy_match = fuzzy_match - self.dump_mode = dump_mode - self.compared_file_type = compared_file_type + def __init__(self, **kwargs): + self.stack_mode = kwargs.get('stack_mode', False) + self.auto_analyze = kwargs.get('auto_analyze', True) + self.fuzzy_match = kwargs.get('fuzzy_match', False) + self.highlight = kwargs.get('highlight', True) + self.dump_mode = kwargs.get('dump_mode', Const.SUMMARY) + self.compared_file_type = kwargs.get('compared_file_type', Const.DUMP_JSON_FILE) class MappingConfig: diff --git a/debug/accuracy_tools/msprobe/core/compare/highlight.py b/debug/accuracy_tools/msprobe/core/compare/highlight.py index 71959d77d1ad3f3e293b103c6844d9641c9e51be..0d820e5f9d5f20b8464f4ff2d5cc7d04b507e969 100644 --- a/debug/accuracy_tools/msprobe/core/compare/highlight.py +++ b/debug/accuracy_tools/msprobe/core/compare/highlight.py @@ -349,28 +349,19 @@ class HighLight: self.update_highlight_err_msg(result_df, highlight_dict) # add highlight err_msg - wb = openpyxl.Workbook() - ws = wb.active - - # write header - logger.info('Initializing Excel file.') - self.handle_multi_process_malicious_value_check(self.df_malicious_value_check, result_df) + wb = openpyxl.Workbook() + ws = wb.active result_df_convert = result_df.applymap(self.compare_result_df_convert) - for row in dataframe_to_rows(result_df_convert, index=False, header=True): ws.append(row) # 对可疑数据标色 logger.info('Coloring Excel in progress.') + red_fill = PatternFill(start_color=CompareConst.RED, end_color=CompareConst.RED, fill_type="solid") + yellow_fill = PatternFill(start_color=CompareConst.YELLOW, end_color=CompareConst.YELLOW, fill_type="solid") col_len = len(result_df.columns) - red_fill = PatternFill( - start_color=CompareConst.RED, end_color=CompareConst.RED, fill_type="solid" - ) - yellow_fill = PatternFill( - start_color=CompareConst.YELLOW, end_color=CompareConst.YELLOW, fill_type="solid", - ) for i in highlight_dict.get("red_rows", []): for j in range(1, col_len + 1): ws.cell(row=i + 2, column=j).fill = red_fill # 2因为ws.cell中的row或column需要>=1,数据从第2行开始 @@ -378,7 +369,6 @@ class HighLight: for j in range(1, col_len + 1): ws.cell(row=i + 2, column=j).fill = yellow_fill - logger.info('Saving Excel file to disk: %s' % file_path) save_workbook(wb, file_path) def handle_multi_process_malicious_value_check(self, func, result_df): diff --git a/debug/accuracy_tools/msprobe/core/compare/utils.py b/debug/accuracy_tools/msprobe/core/compare/utils.py index 1e67a8020d4af209b91fc7481c5eac4cee164a69..a37c64bcc72750598e7f74327072367550b0ada5 100644 --- a/debug/accuracy_tools/msprobe/core/compare/utils.py +++ b/debug/accuracy_tools/msprobe/core/compare/utils.py @@ -561,6 +561,8 @@ def _compare_parser(parser): help=" Whether to give advisor.", required=False) parser.add_argument("-f", "--fuzzy_match", dest="fuzzy_match", action="store_true", help=" Whether to perform a fuzzy match on the api name.", required=False) + parser.add_argument("-hl", "--highlight", dest="highlight", action="store_true", + help=" Whether to set result highlighting.", required=False) parser.add_argument("-cm", "--cell_mapping", dest="cell_mapping", type=str, nargs='?', const=True, help=" The cell mapping file path.", required=False) parser.add_argument("-am", "--api_mapping", dest="api_mapping", type=str, nargs='?', const=True, diff --git a/debug/accuracy_tools/msprobe/mindspore/compare/ms_compare.py b/debug/accuracy_tools/msprobe/mindspore/compare/ms_compare.py index 42d973a0e896dc5ee700e17f435275969eee1025..ae3dfa63d78b2b7e4553a4f68df90aa84dc362ea 100644 --- a/debug/accuracy_tools/msprobe/mindspore/compare/ms_compare.py +++ b/debug/accuracy_tools/msprobe/mindspore/compare/ms_compare.py @@ -35,8 +35,16 @@ def ms_compare(input_param, output_path, **kwargs): config.data_mapping = generate_data_mapping_by_layer_mapping(input_param, config.layer_mapping, output_path) is_cross_framework = check_cross_framework(input_param.get('bench_json_path')) - mode_config = ModeConfig(config.stack_mode, config.auto_analyze, config.fuzzy_match, - config.dump_mode, config.compared_file_type) + + config_dict = { + 'stack_mode': config.stack_mode, + 'auto_analyze': config.auto_analyze, + 'fuzzy_match': config.fuzzy_match, + 'highlight': config.highlight, + 'dump_mode': config.dump_mode, + 'compared_file_type': config.compared_file_type + } + mode_config = ModeConfig(**config_dict) mapping_config = MappingConfig(config.cell_mapping, config.api_mapping, config.data_mapping) ms_comparator = Comparator(read_real_data, mode_config, mapping_config, is_cross_framework) ms_comparator.compare_core(input_param, output_path, suffix=config.suffix) diff --git a/debug/accuracy_tools/msprobe/pytorch/compare/pt_compare.py b/debug/accuracy_tools/msprobe/pytorch/compare/pt_compare.py index 96e9fc88e8aa3457b44b2011732738e0d4689887..0f9c9f26a94dbf4ddda801f6d3d95b87bccb23bc 100644 --- a/debug/accuracy_tools/msprobe/pytorch/compare/pt_compare.py +++ b/debug/accuracy_tools/msprobe/pytorch/compare/pt_compare.py @@ -26,8 +26,15 @@ def read_real_data(npu_dir, npu_data_name, bench_dir, bench_data_name, _) -> tup def compare(input_param, output_path, **kwargs): config = setup_comparison(input_param, output_path, **kwargs) - mode_config = ModeConfig(config.stack_mode, config.auto_analyze, config.fuzzy_match, - config.dump_mode, config.compared_file_type) + config_dict = { + 'stack_mode': config.stack_mode, + 'auto_analyze': config.auto_analyze, + 'fuzzy_match': config.fuzzy_match, + 'highlight': config.highlight, + 'dump_mode': config.dump_mode, + 'compared_file_type': config.compared_file_type + } + mode_config = ModeConfig(**config_dict) mapping_config = MappingConfig(data_mapping=config.data_mapping) pt_comparator = Comparator(read_real_data, mode_config, mapping_config) pt_comparator.compare_core(input_param, output_path, suffix=config.suffix) diff --git a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare.py b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare.py index 107581b5fe58446352082f0dde6c4d0e83a74246..f0b48e8ee5abd897ef83deed43c0737862a741e7 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare.py @@ -353,22 +353,25 @@ class TestUtilsMethods(unittest.TestCase): 'summary': [[1, 1, 1, 1]] } - stack_mode = True - auto_analyze = True - fuzzy_match = False - dump_mode = Const.SUMMARY - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) + config_dict = { + 'stack_mode': True, + 'auto_analyze': True, + 'fuzzy_match': False, + 'dump_mode': Const.SUMMARY, + } + mode_config = ModeConfig(**config_dict) result = ParseData(mode_config).gen_merge_list(json_data, op_name, stack_json_data) self.assertEqual(result, merge_list) def test_check_op_item_fuzzy(self): - stack_mode = False - auto_analyze = True - dump_mode = Const.SUMMARY - - fuzzy_match = True - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) + config_dict = { + 'stack_mode': False, + 'auto_analyze': True, + 'fuzzy_match': True, + 'dump_mode': Const.SUMMARY, + } + mode_config = ModeConfig(**config_dict) mapping_config = MappingConfig() match = Match(mode_config, mapping_config, cross_frame=False) @@ -381,11 +384,13 @@ class TestUtilsMethods(unittest.TestCase): file_list = [os.path.join(base_dir, 'dump.json'), os.path.join(base_dir, 'dump.json'), os.path.join(base_dir, 'stack.json')] - stack_mode = True - auto_analyze = True - fuzzy_match = False - dump_mode = Const.SUMMARY - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) + config_dict = { + 'stack_mode': True, + 'auto_analyze': True, + 'fuzzy_match': False, + 'dump_mode': Const.SUMMARY, + } + mode_config = ModeConfig(**config_dict) mapping_config = MappingConfig() from msprobe.pytorch.compare.pt_compare import read_real_data @@ -760,11 +765,13 @@ class TestMatch(unittest.TestCase): self.assertTrue(match_result.equals(expected)) def test_match_op_both_last_element(self): - stack_mode = False - auto_analyze = True - fuzzy_match = False - dump_mode = Const.SUMMARY - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) + config_dict = { + 'stack_mode': False, + 'auto_analyze': True, + 'fuzzy_match': False, + 'dump_mode': Const.SUMMARY, + } + mode_config = ModeConfig(**config_dict) mapping_config = MappingConfig() match = Match(mode_config, mapping_config, cross_frame=False) @@ -773,11 +780,13 @@ class TestMatch(unittest.TestCase): self.assertEqual(b, 0) def test_match_op_only_npu_last_element(self): - stack_mode = False - auto_analyze = True - fuzzy_match = False - dump_mode = Const.SUMMARY - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) + config_dict = { + 'stack_mode': False, + 'auto_analyze': True, + 'fuzzy_match': False, + 'dump_mode': Const.SUMMARY, + } + mode_config = ModeConfig(**config_dict) mapping_config = MappingConfig() match = Match(mode_config, mapping_config, cross_frame=False) @@ -786,11 +795,13 @@ class TestMatch(unittest.TestCase): self.assertEqual(b, 0) def test_match_op_only_bench_last_element(self): - stack_mode = False - auto_analyze = True - fuzzy_match = False - dump_mode = Const.SUMMARY - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) + config_dict = { + 'stack_mode': False, + 'auto_analyze': True, + 'fuzzy_match': False, + 'dump_mode': Const.SUMMARY, + } + mode_config = ModeConfig(**config_dict) mapping_config = MappingConfig() match = Match(mode_config, mapping_config, cross_frame=False) diff --git a/debug/accuracy_tools/msprobe/visualization/builder/msprobe_adapter.py b/debug/accuracy_tools/msprobe/visualization/builder/msprobe_adapter.py index 2b7f7886535068824e782c8cfab1b6aa283198e5..cc304c8aa7a12e0335b7990e068f4679f8e35d92 100644 --- a/debug/accuracy_tools/msprobe/visualization/builder/msprobe_adapter.py +++ b/debug/accuracy_tools/msprobe/visualization/builder/msprobe_adapter.py @@ -54,7 +54,13 @@ def run_real_data(dump_path_param, csv_path, framework, is_cross_frame=False): framework: 框架类型, pytorch或mindspore is_cross_frame: 是否进行跨框架比对,仅支持mindspore比pytorch, 其中pytorch为标杆 """ - mode_config = ModeConfig(stack_mode=False, auto_analyze=True, fuzzy_match=False, dump_mode=Const.ALL) + config_dict = { + 'stack_mode': False, + 'auto_analyze': True, + 'fuzzy_match': False, + 'dump_mode': Const.ALL + } + mode_config = ModeConfig(**config_dict) if framework == Const.PT_FRAMEWORK: from msprobe.pytorch.compare.pt_compare import read_real_data