From 2ef9460f07949ea22dd6bdaa75ef982dcb889d8e Mon Sep 17 00:00:00 2001 From: lichangwei Date: Mon, 21 Jul 2025 16:45:36 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90bugfix=E3=80=91config=5Fcheck=E4=BC=98?= =?UTF-8?q?=E5=8C=96=E6=96=B0=E5=A2=9Ewarning=E7=8A=B6=E6=80=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- debug/accuracy_tools/msprobe/core/common/const.py | 3 +++ .../core/config_check/checkers/dataset_checker.py | 3 ++- .../core/config_check/checkers/env_args_checker.py | 10 +++++----- .../config_check/checkers/hyperparameter_checker.py | 12 +++++++----- .../core/config_check/checkers/pip_checker.py | 7 ++++--- .../core/config_check/checkers/random_checker.py | 2 +- .../core/config_check/checkers/weights_checker.py | 3 ++- .../msprobe/core/config_check/utils/utils.py | 10 ++++++++++ .../test/core_ut/config_check/test_config_check.py | 12 ++++++------ 9 files changed, 40 insertions(+), 22 deletions(-) diff --git a/debug/accuracy_tools/msprobe/core/common/const.py b/debug/accuracy_tools/msprobe/core/common/const.py index 591279429..d408b4fb8 100644 --- a/debug/accuracy_tools/msprobe/core/common/const.py +++ b/debug/accuracy_tools/msprobe/core/common/const.py @@ -385,6 +385,9 @@ class Const: MATCH_MODE_NAME = "pure name" MATCH_MODE_MAPPING = "mapping" MATCH_MODE_SIMILARITY = "similarity" + CONFIG_CHECK_PASS = "pass" + CONFIG_CHECK_WARNING = "warning" + CONFIG_CHECK_ERROR = "error" class CompareConst: diff --git a/debug/accuracy_tools/msprobe/core/config_check/checkers/dataset_checker.py b/debug/accuracy_tools/msprobe/core/config_check/checkers/dataset_checker.py index 96ff4809f..84af2b29a 100644 --- a/debug/accuracy_tools/msprobe/core/config_check/checkers/dataset_checker.py +++ b/debug/accuracy_tools/msprobe/core/config_check/checkers/dataset_checker.py @@ -22,6 +22,7 @@ from msprobe.core.config_check.config_checker import register_checker_item, regi from msprobe.core.config_check.utils.utils import config_checking_print, get_tensor_features from msprobe.core.common.decorator import recursion_depth_decorator from msprobe.core.common.framework_adapter import FmkAdp +from msprobe.core.common.const import Const @recursion_depth_decorator("config_check: process_obj") @@ -134,5 +135,5 @@ class DatasetChecker(BaseChecker): cmp_dataset_pack_path = os.path.join(cmp_dir, DatasetChecker.target_name_in_zip) df = compare_dataset(bench_dataset_pack_path, cmp_dataset_pack_path) - pass_check = False not in df['equal'].values + pass_check = Const.CONFIG_CHECK_PASS if False not in df['equal'].values else Const.CONFIG_CHECK_ERROR return DatasetChecker.target_name_in_zip, pass_check, df diff --git a/debug/accuracy_tools/msprobe/core/config_check/checkers/env_args_checker.py b/debug/accuracy_tools/msprobe/core/config_check/checkers/env_args_checker.py index d4f72a6b2..513a9f3b6 100644 --- a/debug/accuracy_tools/msprobe/core/config_check/checkers/env_args_checker.py +++ b/debug/accuracy_tools/msprobe/core/config_check/checkers/env_args_checker.py @@ -21,7 +21,7 @@ import pandas as pd from msprobe.core.common.file_utils import load_json, load_yaml, create_file_with_content, create_file_in_zip from msprobe.core.config_check.checkers.base_checker import BaseChecker from msprobe.core.config_check.config_checker import register_checker_item -from msprobe.core.config_check.utils.utils import config_checking_print +from msprobe.core.config_check.utils.utils import config_checking_print, process_pass_check from msprobe.core.common.const import Const @@ -59,17 +59,17 @@ def compare_env_data(npu_path, bench_path): cmp_env_name = cmp_env["name"] cmp_value = cmp_data.get(cmp_env_name, value[cmp_type]["default_value"]) if not bench_env: - data.append(["only cmp has this env", cmp_env["name"], "", cmp_value, "warning"]) + data.append(["only cmp has this env", cmp_env["name"], "", cmp_value, Const.CONFIG_CHECK_WARNING]) continue bench_env_name = bench_env["name"] bench_value = bench_data.get(bench_env_name, value[bench_type]["default_value"]) if cmp_value != bench_value: - data.append([bench_env_name, cmp_env_name, bench_value, cmp_value, "error"]) + data.append([bench_env_name, cmp_env_name, bench_value, cmp_value, Const.CONFIG_CHECK_ERROR]) else: bench_env_name = bench_env["name"] bench_value = bench_data.get(bench_env_name) if bench_data.get(bench_env_name) else value[bench_type][ "default_value"] - data.append([bench_env_name, "only bench has this env", bench_value, "", "warning"]) + data.append([bench_env_name, "only bench has this env", bench_value, "", Const.CONFIG_CHECK_WARNING]) df = pd.DataFrame(data, columns=EnvArgsChecker.result_header) return df @@ -92,5 +92,5 @@ class EnvArgsChecker(BaseChecker): bench_env_data = os.path.join(bench_dir, EnvArgsChecker.target_name_in_zip) cmp_env_data = os.path.join(cmp_dir, EnvArgsChecker.target_name_in_zip) df = compare_env_data(bench_env_data, cmp_env_data) - pass_check = "error" not in df['level'].values + pass_check = process_pass_check(df['level'].values) return EnvArgsChecker.target_name_in_zip, pass_check, df diff --git a/debug/accuracy_tools/msprobe/core/config_check/checkers/hyperparameter_checker.py b/debug/accuracy_tools/msprobe/core/config_check/checkers/hyperparameter_checker.py index bd23f326b..a1493bb3f 100644 --- a/debug/accuracy_tools/msprobe/core/config_check/checkers/hyperparameter_checker.py +++ b/debug/accuracy_tools/msprobe/core/config_check/checkers/hyperparameter_checker.py @@ -23,7 +23,7 @@ import pandas as pd from msprobe.core.common.utils import check_extern_input_list from msprobe.core.config_check.checkers.base_checker import BaseChecker from msprobe.core.config_check.config_checker import register_checker_item -from msprobe.core.config_check.utils.utils import compare_dict, config_checking_print, update_dict +from msprobe.core.config_check.utils.utils import compare_dict, config_checking_print, update_dict, process_pass_check from msprobe.core.config_check.utils.hyperparameter_parser import ParserFactory from msprobe.core.common.file_utils import (check_file_or_directory_path, create_file_in_zip, load_json, load_yaml) @@ -86,7 +86,7 @@ class HyperparameterChecker(BaseChecker): all_diffs.extend( HyperparameterChecker.compare_param(bench_hyperparameters, cmp_hyperparameters, file_name)) df = pd.DataFrame(all_diffs, columns=HyperparameterChecker.result_header) - pass_check = "error" not in df["level"].values + pass_check = process_pass_check(df["level"].values) return HyperparameterChecker.target_name_in_zip, pass_check, df @staticmethod @@ -102,13 +102,15 @@ class HyperparameterChecker(BaseChecker): if bench_param_value != cmp_param_value: all_diffs.append( [file_name, bench_param_name, matched_cmp_param_name, bench_param_value, cmp_param_value, - matched_with, "error"]) + matched_with, Const.CONFIG_CHECK_ERROR]) del cmp_params[matched_cmp_param_name] else: all_diffs.append( - [file_name, bench_param_name, "Only in benchmark", bench_param_value, "", "", "warning"]) + [file_name, bench_param_name, "Only in benchmark", bench_param_value, "", "", + Const.CONFIG_CHECK_WARNING]) for cmp_param_name, cmp_param_value in cmp_params.items(): - all_diffs.append([file_name, "Only in comparison", cmp_param_name, "", cmp_param_value, "", "warning"]) + all_diffs.append( + [file_name, "Only in comparison", cmp_param_name, "", cmp_param_value, "", Const.CONFIG_CHECK_WARNING]) all_diffs.sort() return all_diffs diff --git a/debug/accuracy_tools/msprobe/core/config_check/checkers/pip_checker.py b/debug/accuracy_tools/msprobe/core/config_check/checkers/pip_checker.py index a35bc3e00..0795ad6bc 100644 --- a/debug/accuracy_tools/msprobe/core/config_check/checkers/pip_checker.py +++ b/debug/accuracy_tools/msprobe/core/config_check/checkers/pip_checker.py @@ -23,8 +23,9 @@ except ImportError: from msprobe.core.common.file_utils import load_yaml, create_file_in_zip from msprobe.core.config_check.checkers.base_checker import BaseChecker from msprobe.core.config_check.config_checker import register_checker_item -from msprobe.core.config_check.utils.utils import config_checking_print +from msprobe.core.config_check.utils.utils import config_checking_print, process_pass_check from msprobe.core.common.file_utils import FileOpen, save_excel +from msprobe.core.common.const import Const dirpath = os.path.dirname(__file__) depend_path = os.path.join(dirpath, "../resource/dependency.yaml") @@ -62,7 +63,7 @@ def compare_pip_data(bench_pip_path, cmp_pip_path, fmk): if bench_version != cmp_version: data.append([package, bench_version if bench_version else 'None', cmp_version if cmp_version else 'None', - "error"]) + Const.CONFIG_CHECK_ERROR]) df = pd.DataFrame(data, columns=PipPackageChecker.result_header) return df @@ -86,5 +87,5 @@ class PipPackageChecker(BaseChecker): bench_pip_path = os.path.join(bench_dir, PipPackageChecker.target_name_in_zip) cmp_pip_path = os.path.join(cmp_dir, PipPackageChecker.target_name_in_zip) df = compare_pip_data(bench_pip_path, cmp_pip_path, fmk) - pass_check = "error" not in df['level'].values + pass_check = process_pass_check(df['level'].values) return PipPackageChecker.target_name_in_zip, pass_check, df diff --git a/debug/accuracy_tools/msprobe/core/config_check/checkers/random_checker.py b/debug/accuracy_tools/msprobe/core/config_check/checkers/random_checker.py index f018922db..a91c0febb 100644 --- a/debug/accuracy_tools/msprobe/core/config_check/checkers/random_checker.py +++ b/debug/accuracy_tools/msprobe/core/config_check/checkers/random_checker.py @@ -331,7 +331,7 @@ class RandomChecker(BaseChecker): cmp_stats_path = os.path.join(cmp_dir, RandomChecker.target_name_in_zip) df = compare_random_calls(bench_stats_path, cmp_stats_path) - pass_check = False not in df['check_result'].values + pass_check = Const.CONFIG_CHECK_PASS if False not in df['check_result'].values else Const.CONFIG_CHECK_ERROR return RandomChecker.target_name_in_zip, pass_check, df diff --git a/debug/accuracy_tools/msprobe/core/config_check/checkers/weights_checker.py b/debug/accuracy_tools/msprobe/core/config_check/checkers/weights_checker.py index f17c62ff9..32716ea2e 100644 --- a/debug/accuracy_tools/msprobe/core/config_check/checkers/weights_checker.py +++ b/debug/accuracy_tools/msprobe/core/config_check/checkers/weights_checker.py @@ -22,6 +22,7 @@ from msprobe.core.config_check.checkers.base_checker import BaseChecker from msprobe.core.config_check.config_checker import register_checker_item, register_pre_forward_fun_list from msprobe.core.config_check.utils.utils import config_checking_print, get_tensor_features from msprobe.core.common.framework_adapter import FmkAdp +from msprobe.core.common.const import Const def collect_weights_data(model): @@ -143,5 +144,5 @@ class WeightsChecker(BaseChecker): bench_weight_pack_path = os.path.join(bench_dir, WeightsChecker.target_name_in_zip) cmp_weight_pack_path = os.path.join(cmp_dir, WeightsChecker.target_name_in_zip) df = compare_weight(bench_weight_pack_path, cmp_weight_pack_path) - pass_check = False not in df['equal'].values + pass_check = Const.CONFIG_CHECK_PASS if False not in df['equal'].values else Const.CONFIG_CHECK_ERROR return WeightsChecker.target_name_in_zip, pass_check, df diff --git a/debug/accuracy_tools/msprobe/core/config_check/utils/utils.py b/debug/accuracy_tools/msprobe/core/config_check/utils/utils.py index 8e6332db8..eedcc34cf 100644 --- a/debug/accuracy_tools/msprobe/core/config_check/utils/utils.py +++ b/debug/accuracy_tools/msprobe/core/config_check/utils/utils.py @@ -19,6 +19,7 @@ import hashlib from msprobe.core.common.framework_adapter import FmkAdp from msprobe.core.common.log import logger +from msprobe.core.common.const import Const def merge_keys(dir_0, dir_1): @@ -105,3 +106,12 @@ def update_dict(ori_dict, new_dict): ori_dict[key] = {"description": "duplicate_value", "values": [ori_dict[key], new_dict[key]]} else: ori_dict[key] = value + + +def process_pass_check(data): + if Const.CONFIG_CHECK_ERROR in data: + return Const.CONFIG_CHECK_ERROR + elif Const.CONFIG_CHECK_WARNING in data: + return Const.CONFIG_CHECK_WARNING + else: + return Const.CONFIG_CHECK_PASS diff --git a/debug/accuracy_tools/msprobe/test/core_ut/config_check/test_config_check.py b/debug/accuracy_tools/msprobe/test/core_ut/config_check/test_config_check.py index 9fb0dc9b6..d2522cc14 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/config_check/test_config_check.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/config_check/test_config_check.py @@ -143,12 +143,12 @@ class TestConfigChecker(unittest.TestCase): total_check_result = read_xlsx(os.path.join(compare_output_dir, ConfigChecker.result_filename)) self.assertEqual(total_check_result.columns.tolist(), ConfigChecker.result_header) target_total_check_result = [ - ['env', False], - ['pip', False], - ['dataset', False], - ['weights', False], - ['hyperparameters', False], - ['random', False] + ['env', "error"], + ['pip', "error"], + ['dataset', "error"], + ['weights', "error"], + ['hyperparameters', "error"], + ['random', "error"] ] self.assertEqual(total_check_result.values.tolist(), target_total_check_result) -- Gitee