From 40bb7526f754f88006b8d662da222fd8a9cd15e3 Mon Sep 17 00:00:00 2001 From: jijiarong Date: Tue, 3 Sep 2024 18:39:37 +0800 Subject: [PATCH] bugfix 2947 --- .../mindspore/compare/ms_graph_compare.py | 23 +++++++++++-------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/debug/accuracy_tools/msprobe/mindspore/compare/ms_graph_compare.py b/debug/accuracy_tools/msprobe/mindspore/compare/ms_graph_compare.py index 69d95ca5e..e844a398f 100644 --- a/debug/accuracy_tools/msprobe/mindspore/compare/ms_graph_compare.py +++ b/debug/accuracy_tools/msprobe/mindspore/compare/ms_graph_compare.py @@ -57,29 +57,32 @@ def generate_path_by_rank_step(base_path, rank_id, step_id): def statistic_data_read(statistic_file_list, statistic_file_path): data_list = [] statistic_data_list = [] + header_index = {'Data Type': None, 'Shape': None, 'Max Value': None, 'Min Value': None, + 'Avg Value': None, 'L2Norm Value': None} for statistic_file in statistic_file_list: with open(statistic_file, "r") as f: csv_reader = csv.reader(f, delimiter=",") header = next(csv_reader) - header_index = {'Data Type': None, 'Shape': None, 'Max Value': None, 'Min Value': None, - 'Avg Value': None, 'L2Norm Value': None} for key in header_index.keys(): for index, value in enumerate(header): if key == value: header_index[key] = index - for key in header_index.keys(): - if header_index[key] is None: - logger.error(f"Data_path {statistic_file_path} has no key {key}") - raise FileCheckException(f"Data_path {statistic_file_path} has no key {key}") statistic_data_list.extend([row for row in csv_reader]) + for key in header_index.keys(): + if header_index[key] is None: + logger.warning(f"Data_path {statistic_file_path} has no key {key}.") + for data in statistic_data_list: compare_key = f"{data[1]}.{data[2]}.{data[3]}.{data[5]}" timestamp = int(data[4]) - data_list.append( - [statistic_file_path, compare_key, timestamp, data[header_index['Data Type']], - data[header_index['Shape']], data[header_index['Max Value']], data[header_index['Min Value']], - data[header_index['Avg Value']], data[header_index['L2Norm Value']]]) + result_data = [statistic_file_path, compare_key, timestamp] + for key in header_index.keys(): + if header_index[key] is None: + result_data.append(np.nan) + else: + result_data.append(data[header_index[key]]) + data_list.append(result_data) return data_list -- Gitee