From 18c4b73915dde12b81a786797488e75a9a85a640 Mon Sep 17 00:00:00 2001 From: Linwei-Ying Date: Wed, 16 Jul 2025 14:44:41 +0800 Subject: [PATCH] compare add parameters check --- .../msprobe/core/compare/compare_cli.py | 12 ++++++++++-- .../accuracy_tools/msprobe/mindspore/common/utils.py | 3 +++ .../msprobe/mindspore/compare/common_dir_compare.py | 4 +++- .../msprobe/mindspore/compare/ms_graph_compare.py | 9 +++++++-- .../msprobe/mindspore/compare/utils.py | 10 +++++++++- .../msprobe/pytorch/compare/pt_compare.py | 5 +++++ 6 files changed, 37 insertions(+), 6 deletions(-) diff --git a/debug/accuracy_tools/msprobe/core/compare/compare_cli.py b/debug/accuracy_tools/msprobe/core/compare/compare_cli.py index 806540df54..91892a1aa8 100644 --- a/debug/accuracy_tools/msprobe/core/compare/compare_cli.py +++ b/debug/accuracy_tools/msprobe/core/compare/compare_cli.py @@ -14,7 +14,7 @@ # limitations under the License. import json -from msprobe.core.common.file_utils import check_file_type, load_json +from msprobe.core.common.file_utils import check_file_type, load_json, check_file_or_directory_path from msprobe.core.common.const import FileCheckConst, Const from msprobe.core.common.utils import CompareException from msprobe.core.common.log import logger @@ -22,6 +22,9 @@ from msprobe.core.common.log import logger def compare_cli(args): input_param = load_json(args.input_path) + if not isinstance(input_param, dict): + logger.error("input_param should be dict, please check!") + raise CompareException(CompareException.INVALID_OBJECT_TYPE_ERROR) npu_path = input_param.get("npu_path", None) bench_path = input_param.get("bench_path", None) if not npu_path: @@ -47,6 +50,8 @@ def compare_cli(args): } if check_file_type(npu_path) == FileCheckConst.FILE and check_file_type(bench_path) == FileCheckConst.FILE: + check_file_or_directory_path(npu_path) + check_file_or_directory_path(bench_path) input_param["npu_json_path"] = input_param.pop("npu_path") input_param["bench_json_path"] = input_param.pop("bench_path") if "stack_path" not in input_param: @@ -68,6 +73,8 @@ def compare_cli(args): } ms_compare(input_param, args.output_path, **kwargs) elif check_file_type(npu_path) == FileCheckConst.DIR and check_file_type(bench_path) == FileCheckConst.DIR: + check_file_or_directory_path(npu_path, isdir=True) + check_file_or_directory_path(bench_path, isdir=True) kwargs = { **common_kwargs, "stack_mode": args.stack_mode, @@ -79,7 +86,8 @@ def compare_cli(args): if input_param.get("rank_id") is not None: ms_graph_compare(input_param, args.output_path) return - if input_param.get('common', False): + common = input_param.get("common", False) + if isinstance(common, bool) and common: common_dir_compare(input_param, args.output_path) return if frame_name == Const.PT_FRAMEWORK: diff --git a/debug/accuracy_tools/msprobe/mindspore/common/utils.py b/debug/accuracy_tools/msprobe/mindspore/common/utils.py index 72e09c7609..595384da41 100644 --- a/debug/accuracy_tools/msprobe/mindspore/common/utils.py +++ b/debug/accuracy_tools/msprobe/mindspore/common/utils.py @@ -95,6 +95,9 @@ def save_tensor_as_npy(tensor, file_path): def convert_to_int(value): + if isinstance(value, bool): + logger.error('The value in rank_id or step should be int, please check!') + raise CompareException(CompareException.INVALID_OBJECT_TYPE_ERROR) try: return int(value) except Exception: diff --git a/debug/accuracy_tools/msprobe/mindspore/compare/common_dir_compare.py b/debug/accuracy_tools/msprobe/mindspore/compare/common_dir_compare.py index ca7b383cd2..7371b12200 100644 --- a/debug/accuracy_tools/msprobe/mindspore/compare/common_dir_compare.py +++ b/debug/accuracy_tools/msprobe/mindspore/compare/common_dir_compare.py @@ -30,9 +30,10 @@ from msprobe.core.common.utils import CompareException from msprobe.core.common.exceptions import FileCheckException from msprobe.core.common.file_utils import check_file_or_directory_path, write_df_to_csv, create_directory, \ check_path_before_create, load_npy -from msprobe.core.common.const import CompareConst, FileCheckConst +from msprobe.core.common.const import CompareConst from msprobe.core.compare.npy_compare import compare_ops_apply from msprobe.core.compare.multiprocessing_compute import check_accuracy +from msprobe.mindspore.compare.utils import check_name_map_dict def common_dir_compare(input_params: Dict, output_dir: str) -> Optional[pd.DataFrame]: @@ -49,6 +50,7 @@ def common_dir_compare(input_params: Dict, output_dir: str) -> Optional[pd.DataF npu_root = Path(input_params.get('npu_path')) bench_root = Path(input_params.get('bench_path')) name_map_dict = input_params.get('map_dict', {}) + check_name_map_dict(name_map_dict) file_tree = build_mirror_file_tree(npu_root, bench_root) # 处理文件比对 diff --git a/debug/accuracy_tools/msprobe/mindspore/compare/ms_graph_compare.py b/debug/accuracy_tools/msprobe/mindspore/compare/ms_graph_compare.py index ecf8e84d13..62b8551ae4 100644 --- a/debug/accuracy_tools/msprobe/mindspore/compare/ms_graph_compare.py +++ b/debug/accuracy_tools/msprobe/mindspore/compare/ms_graph_compare.py @@ -168,8 +168,13 @@ class GraphMSComparator: self.output_path = output_path self.base_npu_path = input_param.get('npu_path', None) self.base_bench_path = input_param.get('bench_path', None) - self.rank_list = [convert_to_int(rank_id) for rank_id in input_param.get('rank_id', [])] - self.step_list = [convert_to_int(step_id) for step_id in input_param.get('step_id', [])] + rank_id_list = input_param.get('rank_id', []) + step_id_list = input_param.get('step_id', []) + if not isinstance(rank_id_list, list) or not isinstance(step_id_list, list): + logger.error("'rank_id' and 'step_id' should both be lists, please check!") + raise CompareException(CompareException.INVALID_OBJECT_TYPE_ERROR) + self.rank_list = [convert_to_int(rank_id) for rank_id in rank_id_list] + self.step_list = [convert_to_int(step_id) for step_id in step_id_list] # split by rank and step, generate rank step path self.npu_rank_step_dict = self.generate_rank_step_path(self.base_npu_path) self.bench_rank_step_dict = self.generate_rank_step_path(self.base_bench_path) diff --git a/debug/accuracy_tools/msprobe/mindspore/compare/utils.py b/debug/accuracy_tools/msprobe/mindspore/compare/utils.py index 7a9c78e8f7..a6f9f4ae55 100644 --- a/debug/accuracy_tools/msprobe/mindspore/compare/utils.py +++ b/debug/accuracy_tools/msprobe/mindspore/compare/utils.py @@ -17,7 +17,8 @@ import os from msprobe.core.common.const import Const from msprobe.core.common.file_utils import load_npy, FileChecker, FileCheckConst -from msprobe.core.common.utils import detect_framework_by_dump_json +from msprobe.core.common.utils import detect_framework_by_dump_json, CompareException, check_op_str_pattern_valid +from msprobe.core.common.log import logger def read_npy_data(dir_path, file_name): @@ -35,3 +36,10 @@ def read_npy_data(dir_path, file_name): def check_cross_framework(bench_json_path): framework = detect_framework_by_dump_json(bench_json_path) return framework == Const.PT_FRAMEWORK + + +def check_name_map_dict(name_map_dict): + if not isinstance(name_map_dict, dict): + logger.error("'map_dict' should be a dict, please check!") + raise CompareException(CompareException.INVALID_OBJECT_TYPE_ERROR) + check_op_str_pattern_valid(str(name_map_dict)) diff --git a/debug/accuracy_tools/msprobe/pytorch/compare/pt_compare.py b/debug/accuracy_tools/msprobe/pytorch/compare/pt_compare.py index 96e9fc88e8..b73c52dce3 100644 --- a/debug/accuracy_tools/msprobe/pytorch/compare/pt_compare.py +++ b/debug/accuracy_tools/msprobe/pytorch/compare/pt_compare.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from msprobe.core.common.utils import CompareException +from msprobe.core.common.log import logger from msprobe.core.compare.acc_compare import Comparator, ModeConfig, MappingConfig, setup_comparison from msprobe.pytorch.compare.utils import read_pt_data @@ -24,6 +26,9 @@ def read_real_data(npu_dir, npu_data_name, bench_dir, bench_data_name, _) -> tup def compare(input_param, output_path, **kwargs): + if not isinstance(input_param, dict): + logger.error("input_param should be dict, please check!") + raise CompareException(CompareException.INVALID_OBJECT_TYPE_ERROR) config = setup_comparison(input_param, output_path, **kwargs) mode_config = ModeConfig(config.stack_mode, config.auto_analyze, config.fuzzy_match, -- Gitee