diff --git a/debug/accuracy_tools/msprobe/core/compare/layer_mapping/layer_mapping.py b/debug/accuracy_tools/msprobe/core/compare/layer_mapping/layer_mapping.py index 91927f963a9170bd3ee218ff04f6302f01d9ee7c..fd45ef1488d881ebc062e3589ec29c947d59c750 100644 --- a/debug/accuracy_tools/msprobe/core/compare/layer_mapping/layer_mapping.py +++ b/debug/accuracy_tools/msprobe/core/compare/layer_mapping/layer_mapping.py @@ -18,12 +18,12 @@ from collections import defaultdict from msprobe.core.common.const import CompareConst, Const from msprobe.core.common.file_utils import load_json, load_yaml, save_yaml -from msprobe.core.common.utils import (add_time_with_yaml, - detect_framework_by_dump_json, - get_stack_construct_by_dump_json_path) +from msprobe.core.common.utils import add_time_with_yaml, detect_framework_by_dump_json, \ + get_stack_construct_by_dump_json_path, CompareException from msprobe.core.compare.layer_mapping.data_scope_parser import get_dump_data_items from msprobe.core.compare.utils import read_op, reorder_op_name_list from msprobe.core.common.decorator import recursion_depth_decorator +from msprobe.core.common.log import logger class LayerTrie: @@ -63,7 +63,11 @@ class LayerTrie: node = node.children[name] if index >= len(node.data_items[state]): return default_value - return node.data_items[state][index] + if node.data_items[state]: + return node.data_items[state][index] + else: + logger.error(f"node.data_items of state:{state} is empty, please check.") + raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) def save_to_yaml(self, output_path): result = {f"{self.type_name} @ {self}": self.convert_to_dict(self)} diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py index 15e18fe854e10f053d5c4c1f2f9f79800d9b1690..8362b551a1ad5c7e18b593c73c25aef1a1dc9def 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py @@ -84,8 +84,8 @@ def split_json_file(input_file, num_splits, filter_api): for file in split_files: try: remove_path(file) - except FileNotFoundError: - logger.error(f"File not found and could not be deleted: {file}") + except Exception: + logger.error(f"File not found or could not be deleted: {file}") msg = 'ERROR: Split json file failed, please check the input file and try again.' raise CompareException(CompareException.PARSE_FILE_ERROR, msg) from e return split_files, total_items