From 3425b9fb1a25793aa510e8218a2b99462b8f55eb Mon Sep 17 00:00:00 2001 From: fandawei Date: Mon, 14 Oct 2024 10:55:00 +0800 Subject: [PATCH] generate api mapping --- .../msprobe/mindspore/compare/ms_compare.py | 34 +++++++++++++------ 1 file changed, 23 insertions(+), 11 deletions(-) diff --git a/debug/accuracy_tools/msprobe/mindspore/compare/ms_compare.py b/debug/accuracy_tools/msprobe/mindspore/compare/ms_compare.py index e76ac16af4..19e5f522a1 100644 --- a/debug/accuracy_tools/msprobe/mindspore/compare/ms_compare.py +++ b/debug/accuracy_tools/msprobe/mindspore/compare/ms_compare.py @@ -320,6 +320,28 @@ def check_cross_framework(bench_json_path): return True return False +def generate_api_mapping(data_mapping): + api_mapping = {} + for key, value in data_mapping.items(): + api_key = key.rsplit('.', 2)[0] if '.' in key else key + api_value = value.rsplit('.', 2)[0] if isinstance(value, str) and '.' in value else value + api_mapping[api_key] = api_value + return api_mapping + +def get_mapping_from_layer_mapping(input_param, output_path, mapping_path): + pt_stack, pt_construct = struct_json_get(input_param, Const.PT_FRAMEWORK) + ms_stack, ms_construct = struct_json_get(input_param, Const.MS_FRAMEWORK) + mapping = load_yaml(mapping_path) + ms_mapping_result = modify_mapping_with_stack(ms_stack, ms_construct) + pt_mapping_result = modify_mapping_with_stack(pt_stack, pt_construct) + layer_mapping = get_layer_mapping(ms_mapping_result, pt_mapping_result, mapping) + data_mapping = generate_file_mapping(input_param.get("npu_json_path"), input_param.get("bench_json_path"), layer_mapping) + api_mapping = generate_api_mapping(data_mapping) + + data_mapping_name = add_time_with_yaml(f"data_mapping") + data_mapping_path = os.path.join(os.path.realpath(output_path), f"{data_mapping_name}") + save_yaml(data_mapping_path, data_mapping) + return data_mapping, api_mapping def ms_compare(input_param, output_path, **kwargs): try: @@ -339,17 +361,7 @@ def ms_compare(input_param, output_path, **kwargs): logger.error('Compare failed. Please check the arguments and do it again!') raise CompareException(error.code) from error if layer_mapping: - pt_stack, pt_construct = struct_json_get(input_param, Const.PT_FRAMEWORK) - ms_stack, ms_construct = struct_json_get(input_param, Const.MS_FRAMEWORK) - mapping = load_yaml(layer_mapping) - ms_mapping_result = modify_mapping_with_stack(ms_stack, ms_construct) - pt_mapping_result = modify_mapping_with_stack(pt_stack, pt_construct) - layer_mapping = get_layer_mapping(ms_mapping_result, pt_mapping_result, mapping) - data_mapping = generate_file_mapping(input_param.get("npu_json_path"), input_param.get("bench_json_path"), layer_mapping) - - data_mapping_name = add_time_with_yaml(f"data_mapping") - data_mapping_path = os.path.join(os.path.realpath(output_path), f"{data_mapping_name}") - save_yaml(data_mapping_path, data_mapping) + data_mapping, _ = get_mapping_from_layer_mapping(input_param, output_path, layer_mapping) is_cross_framework = check_cross_framework(input_param.get("bench_json_path")) ms_comparator = MSComparator(cell_mapping, api_mapping, data_mapping, is_cross_framework) ms_comparator.compare_core(input_param, output_path, stack_mode=stack_mode, -- Gitee