From 0ec9e523dc11e720fab0a36b35fcc9185fcad6e4 Mon Sep 17 00:00:00 2001
From: l30036321
Date: Thu, 28 Aug 2025 19:38:55 +0800
Subject: [PATCH] config check support mindformers and megatron

---
 .../checkers/hyperparameter_checker.py        | 28 +++++++++++++++----
 .../config_check/resource/hyperparameter.yaml | 12 +++++++-
 .../utils/hyperparameter_parser.py            |  2 +-
 3 files changed, 35 insertions(+), 7 deletions(-)

diff --git a/debug/accuracy_tools/msprobe/core/config_check/checkers/hyperparameter_checker.py b/debug/accuracy_tools/msprobe/core/config_check/checkers/hyperparameter_checker.py
index a1493bb3fc..fd6a5388e0 100644
--- a/debug/accuracy_tools/msprobe/core/config_check/checkers/hyperparameter_checker.py
+++ b/debug/accuracy_tools/msprobe/core/config_check/checkers/hyperparameter_checker.py
@@ -36,6 +36,20 @@ parameter_name_mapping = load_yaml(os.path.realpath(hyperparameters_path))
 hyperparameters_dict = {}
 
 
+def refine_json_keys(json_dcit):
+    new_dict = {}
+    for key in json_dcit:
+        new_key = key.split(Const.SEP)[-1].replace("-", "_")
+        new_dict[new_key] = key
+    return new_dict
+
+
+def to_str_if_number(value):
+    if isinstance(value, (int, float)):
+        return str(value)
+    return value
+
+
 @register_checker_item("hyperparameter")
 class HyperparameterChecker(BaseChecker):
     target_name_in_zip = "hyperparameters"
@@ -92,13 +106,17 @@ class HyperparameterChecker(BaseChecker):
     @staticmethod
     def compare_param(bench_params, cmp_params, file_name):
         all_diffs = []
-        bench_param_names = bench_params.keys()
-        for bench_param_name in bench_param_names:
+        bench_params_refined = refine_json_keys(bench_params)
+        cmp_params_refined = refine_json_keys(cmp_params)
+
+        for bench_param_name in bench_params_refined:
             matched_cmp_param_name, matched_with = HyperparameterChecker._fuzzy_match_parameter(bench_param_name,
-                                                                                                cmp_params)
-            bench_param_value = bench_params[bench_param_name]
+                                                                                                cmp_params_refined)
+            matched_cmp_param_name = cmp_params_refined.get(matched_cmp_param_name)
+            bench_param_name = bench_params_refined.get(bench_param_name)
+            bench_param_value = to_str_if_number(bench_params[bench_param_name])
             if matched_cmp_param_name:
-                cmp_param_value = cmp_params[matched_cmp_param_name]
+                cmp_param_value = to_str_if_number(cmp_params[matched_cmp_param_name])
                 if bench_param_value != cmp_param_value:
                     all_diffs.append(
                         [file_name, bench_param_name, matched_cmp_param_name, bench_param_value, cmp_param_value,
diff --git a/debug/accuracy_tools/msprobe/core/config_check/resource/hyperparameter.yaml b/debug/accuracy_tools/msprobe/core/config_check/resource/hyperparameter.yaml
index 5cff815717..4ec150331e 100644
--- a/debug/accuracy_tools/msprobe/core/config_check/resource/hyperparameter.yaml
+++ b/debug/accuracy_tools/msprobe/core/config_check/resource/hyperparameter.yaml
@@ -18,4 +18,14 @@ weight_decay:
 
 dropout_rate:
   - dropout
-  - drop_rate
\ No newline at end of file
+  - drop_rate
+
+compute_dtype:
+  - bf16
+  - fp32
+
+residual_dtype:
+  - fp32_residual_connection
+
+softmax_compute_dtype:
+  - attention_softmax_in_fp32
diff --git a/debug/accuracy_tools/msprobe/core/config_check/utils/hyperparameter_parser.py b/debug/accuracy_tools/msprobe/core/config_check/utils/hyperparameter_parser.py
index a524504c47..9e02cf5c92 100644
--- a/debug/accuracy_tools/msprobe/core/config_check/utils/hyperparameter_parser.py
+++ b/debug/accuracy_tools/msprobe/core/config_check/utils/hyperparameter_parser.py
@@ -96,7 +96,7 @@ class YamlParser(Parser):
                 new_prefix = prefix + Const.SEP + key if prefix else key
                 self.recursive_parse_parameters(value, new_prefix)
         elif isinstance(parameters, list):
-            if all(isinstance(x, (int, float, str, bool))for x in parameters):
+            if all(isinstance(x, (int, float, str, bool, list)) for x in parameters):
                 self.hyperparameters.update({prefix: parameters})
             else:
                 for idx, value in enumerate(parameters):
-- 
Gitee