diff --git a/debug/accuracy_tools/msprobe/core/compare/layer_mapping/layer_mapping.py b/debug/accuracy_tools/msprobe/core/compare/layer_mapping/layer_mapping.py index 91927f963a9170bd3ee218ff04f6302f01d9ee7c..fd45ef1488d881ebc062e3589ec29c947d59c750 100644 --- a/debug/accuracy_tools/msprobe/core/compare/layer_mapping/layer_mapping.py +++ b/debug/accuracy_tools/msprobe/core/compare/layer_mapping/layer_mapping.py @@ -18,12 +18,12 @@ from collections import defaultdict from msprobe.core.common.const import CompareConst, Const from msprobe.core.common.file_utils import load_json, load_yaml, save_yaml -from msprobe.core.common.utils import (add_time_with_yaml, - detect_framework_by_dump_json, - get_stack_construct_by_dump_json_path) +from msprobe.core.common.utils import add_time_with_yaml, detect_framework_by_dump_json, \ + get_stack_construct_by_dump_json_path, CompareException from msprobe.core.compare.layer_mapping.data_scope_parser import get_dump_data_items from msprobe.core.compare.utils import read_op, reorder_op_name_list from msprobe.core.common.decorator import recursion_depth_decorator +from msprobe.core.common.log import logger class LayerTrie: @@ -63,7 +63,11 @@ class LayerTrie: node = node.children[name] if index >= len(node.data_items[state]): return default_value - return node.data_items[state][index] + if node.data_items[state]: + return node.data_items[state][index] + else: + logger.error(f"node.data_items of state:{state} is empty, please check.") + raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) def save_to_yaml(self, output_path): result = {f"{self.type_name} @ {self}": self.convert_to_dict(self)} diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py index 110685e9900b9557f78da537bf23dd9ac1c14b11..a60217fbd21cb2d6ce8603aad8c12b40a77722fa 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py @@ -76,8 +76,20 @@ def split_json_file(input_file, num_splits, filter_api): } } split_filename = os.path.join(input_dir, f"temp_part{i}.json") - save_json(split_filename, temp_data) split_files.append(split_filename) + try: + save_json(split_filename, temp_data) + except Exception as e: + logger.error(f"An error occurred while saving split file: {e}") + for file in split_files: + try: + remove_path(file) + except FileNotFoundError: + logger.error(f"File not found and could not be deleted: {file}") + except Exception: + logger.error(f"File not found or could not be deleted: {file}") + msg = 'ERROR: Split json file failed, please check the input file and try again.' + raise CompareException(CompareException.PARSE_FILE_ERROR, msg) from e return split_files, total_items