diff --git a/debug/accuracy_tools/msprobe/core/compare/compare_cli.py b/debug/accuracy_tools/msprobe/core/compare/compare_cli.py index 4daf17cd78b7508991d1b495d18a8f7369348672..1eee29658141dd75d7589c60b6b1536c584cdd99 100644 --- a/debug/accuracy_tools/msprobe/core/compare/compare_cli.py +++ b/debug/accuracy_tools/msprobe/core/compare/compare_cli.py @@ -36,7 +36,8 @@ def compare_cli(args): "cell_mapping": args.cell_mapping, "api_mapping": args.api_mapping, "data_mapping": args.data_mapping, - "layer_mapping": args.layer_mapping + "layer_mapping": args.layer_mapping, + "framework2": args.framework2 } ms_compare(input_param, args.output_path, **kwargs) diff --git a/debug/accuracy_tools/msprobe/core/compare/utils.py b/debug/accuracy_tools/msprobe/core/compare/utils.py index 89a9ff5a647c675ab818f20ab1a7a674d7d278fb..1a86b46c854146e53942a047b9aadadabed7218d 100644 --- a/debug/accuracy_tools/msprobe/core/compare/utils.py +++ b/debug/accuracy_tools/msprobe/core/compare/utils.py @@ -431,3 +431,6 @@ def _compare_parser(parser): help=" The data mapping file path.", required=False) parser.add_argument("-lm", "--layer_mapping", dest="layer_mapping", type=str, help=" The layer mapping file path.", required=False) + parser.add_argument('-f2', '--framework2', required=False, choices=[Const.PT_FRAMEWORK, Const.MS_FRAMEWORK], + default=Const.PT_FRAMEWORK, help='The compared deep learning framework. ' + 'This is only valid when framework is mindspore') diff --git a/debug/accuracy_tools/msprobe/mindspore/compare/ms_compare.py b/debug/accuracy_tools/msprobe/mindspore/compare/ms_compare.py index bdabf1de989e7b9f9eb131ef3d2d791bb115cb7e..6e64476ae0250ccdc623086a46921deea38002e8 100644 --- a/debug/accuracy_tools/msprobe/mindspore/compare/ms_compare.py +++ b/debug/accuracy_tools/msprobe/mindspore/compare/ms_compare.py @@ -15,12 +15,12 @@ from msprobe.mindspore.compare.modify_mapping import modify_mapping_with_stack from msprobe.mindspore.compare.layer_mapping import get_layer_mapping class MSComparator(Comparator): - def __init__(self, cell_mapping=None, api_mapping=None, data_mapping=None): + def __init__(self, cell_mapping=None, api_mapping=None, data_mapping=None, framework2=Const.PT_FRAMEWORK): self.frame_name = MSComparator.__name__ self.cell_mapping = cell_mapping self.api_mapping = api_mapping self.data_mapping = data_mapping - self.cross_frame = cell_mapping is not None or api_mapping is not None or data_mapping is not None + self.cross_frame = framework2 == Const.PT_FRAMEWORK self.cell_mapping_dict = self.load_mapping_file(self.cell_mapping) self.api_mapping_dict = self.load_mapping_file(self.api_mapping) if api_mapping is not None: @@ -301,6 +301,7 @@ def ms_compare(input_param, output_path, **kwargs): api_mapping = kwargs.get('api_mapping', None) data_mapping = kwargs.get('data_mapping', None) layer_mapping = kwargs.get('layer_mapping', None) + framework2 = kwargs.get('framework2', Const.PT_FRAMEWORK) summary_compare, md5_compare = task_dumppath_get(input_param) check_configuration_param(stack_mode, auto_analyze, fuzzy_match, input_param.get('is_print_compare_log', True)) @@ -322,7 +323,7 @@ def ms_compare(input_param, output_path, **kwargs): data_mapping_path = os.path.join(os.path.realpath(output_path), f"{data_mapping_name}") save_yaml(data_mapping_path, data_mapping) - ms_comparator = MSComparator(cell_mapping, api_mapping, data_mapping) + ms_comparator = MSComparator(cell_mapping, api_mapping, data_mapping, framework2) ms_comparator.compare_core(input_param, output_path, stack_mode=stack_mode, auto_analyze=auto_analyze, fuzzy_match=fuzzy_match, summary_compare=summary_compare, md5_compare=md5_compare)