diff --git a/debug/accuracy_tools/msprobe/mindspore/compare/ms_compare.py b/debug/accuracy_tools/msprobe/mindspore/compare/ms_compare.py index bdabf1de989e7b9f9eb131ef3d2d791bb115cb7e..640b6f587807898375e3cc135d31706d5e9162b0 100644 --- a/debug/accuracy_tools/msprobe/mindspore/compare/ms_compare.py +++ b/debug/accuracy_tools/msprobe/mindspore/compare/ms_compare.py @@ -15,12 +15,12 @@ from msprobe.mindspore.compare.modify_mapping import modify_mapping_with_stack from msprobe.mindspore.compare.layer_mapping import get_layer_mapping class MSComparator(Comparator): - def __init__(self, cell_mapping=None, api_mapping=None, data_mapping=None): + def __init__(self, cell_mapping=None, api_mapping=None, data_mapping=None, is_cross_framework=False): self.frame_name = MSComparator.__name__ self.cell_mapping = cell_mapping self.api_mapping = api_mapping self.data_mapping = data_mapping - self.cross_frame = cell_mapping is not None or api_mapping is not None or data_mapping is not None + self.cross_frame = is_cross_framework self.cell_mapping_dict = self.load_mapping_file(self.cell_mapping) self.api_mapping_dict = self.load_mapping_file(self.api_mapping) if api_mapping is not None: @@ -292,6 +292,23 @@ def generate_file_mapping(npu_json_path, bench_json_path, mapping_list): return result +def check_cross_framework(bench_json_path): + bench_json_data = load_json(bench_json_path) + bench_data = bench_json_data.get("data", {}) + for _, input_output in bench_data.items(): + input_args = input_output.get("input_args", []) + for input_i in input_args: + data_name = input_i.get("data_name", "") + if data_name.endswith(".pt"): + return True + output = input_output.get("output", []) + for output_i in output: + data_name = output_i.get("data_name", "") + if data_name.endswith(".pt"): + return True + return False + + def ms_compare(input_param, output_path, **kwargs): try: stack_mode = kwargs.get('stack_mode', False) @@ -321,8 +338,8 @@ def ms_compare(input_param, output_path, **kwargs): data_mapping_name = add_time_with_yaml(f"data_mapping") data_mapping_path = os.path.join(os.path.realpath(output_path), f"{data_mapping_name}") save_yaml(data_mapping_path, data_mapping) - - ms_comparator = MSComparator(cell_mapping, api_mapping, data_mapping) + is_cross_framework = check_cross_framework(input_param.get("bench_json_path")) + ms_comparator = MSComparator(cell_mapping, api_mapping, data_mapping, is_cross_framework) ms_comparator.compare_core(input_param, output_path, stack_mode=stack_mode, auto_analyze=auto_analyze, fuzzy_match=fuzzy_match, summary_compare=summary_compare, md5_compare=md5_compare)