def struct_json_get(input_param, framework):
    """Load ``stack.json`` and ``construct.json`` for one framework's dump.

    Args:
        input_param: Dict of compare inputs. The dump location is read from
            ``"{prefix}_path"`` and, when that key is absent, from
            ``"{prefix}_json_path"`` (the key ``compare_cli`` renames it to).
            ``prefix`` is ``"bench"`` for ``Const.PT_FRAMEWORK`` and ``"npu"``
            for ``Const.MS_FRAMEWORK``.
        framework: ``Const.PT_FRAMEWORK`` or ``Const.MS_FRAMEWORK``.

    Returns:
        ``(stack, construct)`` as parsed from the two JSON files, or
        ``({}, {})`` when neither file exists next to the dump.

    Raises:
        CompareException: ``INVALID_PARAM_ERROR`` for an unknown framework;
            ``INVALID_PATH_ERROR`` when no dump path is configured or when
            only one of the two files exists.
    """
    if framework == Const.PT_FRAMEWORK:
        prefix = "bench"
    elif framework == Const.MS_FRAMEWORK:
        prefix = "npu"
    else:
        logger.error("Error framework found.")
        raise CompareException(CompareException.INVALID_PARAM_ERROR)

    # compare_cli moves "<prefix>_path" to "<prefix>_json_path" before calling
    # ms_compare, so accept either key.
    frame_json_path = input_param.get(f"{prefix}_path") or input_param.get(f"{prefix}_json_path")
    if not frame_json_path:
        logger.error("Please check the json path is valid.")
        raise CompareException(CompareException.INVALID_PATH_ERROR)
    # A dump.json file path is accepted too: its siblings live in the same directory.
    if os.path.isfile(frame_json_path):
        frame_json_path = os.path.dirname(frame_json_path)

    stack_json = os.path.join(frame_json_path, "stack.json")
    construct_json = os.path.join(frame_json_path, "construct.json")

    # os.path.join never yields an empty string, so test for the files themselves.
    has_stack = os.path.isfile(stack_json)
    has_construct = os.path.isfile(construct_json)
    if not has_stack and not has_construct:
        logger.info("stack.json and construct.json not found, skip loading them.")
        return {}, {}
    if not (has_stack and has_construct):
        logger.error("stack or construct json not found, please check")
        raise CompareException(CompareException.INVALID_PATH_ERROR)

    with FileOpen(stack_json, 'r') as stack_f:
        stack = json.load(stack_f)
    with FileOpen(construct_json, 'r') as construct_f:
        construct = json.load(construct_f)
    return stack, construct
# Key prefixes that take the stack-based renaming path below.
_API_PREFIXES = ("Functional", "Tensor", "Torch", "NPU", "Jit", "MintFunctional", "Mint", "Primitive")
# Stack frames whose function part contains one of these are ignored.
_FUNC_SKIP_LIST = ("construct", "__call__")
# Stack frames whose file part contains one of these are ignored.
_FILE_SKIP_LIST = ("site-packages/mindspore", "package/mindspore", "msprobe", "torch")
_DIRECTION_SUFFIXES = (".forward", ".backward")


def _remove_direction_suffix(name):
    """Return *name* without one trailing ``.forward`` / ``.backward``."""
    for suffix in _DIRECTION_SUFFIXES:
        if name.endswith(suffix):
            return name[:-len(suffix)]
    return name


def _collect_stack_funcs(code_list, func_name, parent_name):
    """Return function names from the selected stack frames, outermost first.

    The window runs from the last frame containing *func_name* to the last
    frame containing *parent_name*. Each frame is split on ``,``: part 0 is
    matched against the file skip list, and the second whitespace-separated
    token of part 2 is taken as the function name.
    """
    start_pos = end_pos = -1
    for idx, frame in enumerate(code_list):
        if func_name in frame:
            start_pos = idx
        elif parent_name in frame:
            end_pos = idx

    # NOTE(review): when a marker is not found its position stays -1, so the
    # slice silently becomes "up to the last frame" / empty. Kept as-is.
    res_list = []
    for frame in code_list[start_pos:end_pos]:
        ele_list = frame.split(',')
        if len(ele_list) < 3:
            # Malformed frame: no function part to read.
            continue
        if any(skip in ele_list[0] for skip in _FILE_SKIP_LIST):
            continue
        func_ele = ele_list[2]
        if any(skip in func_ele for skip in _FUNC_SKIP_LIST):
            continue
        tokens = func_ele.split()
        if len(tokens) < 2:
            continue
        res_list.append(tokens[1])
    return res_list[::-1]


def modify_mapping_with_stack(stack, construct):
    """Build a name mapping for dump entries from ``stack`` and ``construct``.

    Args:
        stack: Dict mapping a dump key to its list of call-stack lines.
        construct: Dict mapping a dump key to its parent scope name
            (a falsy value means "no parent").

    Returns:
        Dict mapping each rewritten name to
        ``{Const.ORIGIN_DATA: key, Const.SCOPE: parent, Const.STACK: names}``,
        where ``names`` is the dotted list of stack function names or ``None``.
        Returns ``{}`` when either input is empty. Keys that cannot be
        resolved (no stack entry, malformed names) are logged and skipped.
    """
    if not stack or not construct:
        return {}

    # Treat the data as MindSpore-style if any key mentions "Cell".
    is_ms = any("Cell" in name for name in construct)
    final_pres = {}
    for key in construct:
        key_components = key.split('.')

        if not key.startswith(_API_PREFIXES):
            # Other keys: drop the second-to-last component and keep the scope as-is.
            final_res_key = '.'.join(key_components[:-2] + [key_components[-1]])
            final_pres[final_res_key] = {
                Const.ORIGIN_DATA: key, Const.SCOPE: construct[key], Const.STACK: None
            }
            continue

        code_list = stack.get(key, None)
        if not code_list:
            logger.info(f"{key} not found in code stack")
            continue

        if not construct.get(key, None):
            # No parent: rename the leading component to the standard Module/Cell.
            key_components[0] = "Cell" if is_ms else "Module"
            # Repeat the node name as its type (e.g. add -> add.add); the name sits at -3.
            if len(key_components) < 3 or key_components[-3].isdigit():
                logger.warning("key in construct.json is shorter than 3 parts or not name valid.")
                duplicated_components = key_components
            else:
                duplicated_components = key_components[:-2] + key_components[-3:]
            modified_key = '.'.join(duplicated_components)
            modified_key = modified_key.replace(".forward", "").replace(".backward", "")
            final_pres[modified_key] = {Const.ORIGIN_DATA: key, Const.SCOPE: None, Const.STACK: None}
            continue

        parent = construct[key].split('.')
        if len(parent) < 4:
            logger.warning(f"{construct[key]} in construct.json is not valid")
            continue
        # Parent looks like {name}.Class.count.Xward or {name}.Class.count.Xward.ele_number.
        parent_idx = -4 if not parent[-4].isdigit() else -5
        if len(parent) < -parent_idx:
            # A numeric parent[-4] needs a fifth component to hold the name.
            logger.warning(f"{construct[key]} in construct.json is not valid")
            continue
        if len(key_components) < 3:
            logger.warning(f"{key} in construct.json is not valid")
            continue

        parent_name = parent[parent_idx]
        # A trailing 's' is dropped before matching (presumably plural container names - confirm).
        if parent_name.endswith('s'):
            parent_name = parent_name[:-1]
        # Key looks like {name}.count.Xward, so the API name is third from the end.
        func_name = key_components[-3]

        reversed_list = _collect_stack_funcs(code_list, func_name, parent_name)
        # Parent path up to its node name + stack function names + key with its name repeated.
        final_res_key = '.'.join(
            parent[:parent_idx + 1] + reversed_list + key_components[1:-2] + key_components[-3:]
        )
        # str.strip() removes a character set from both ends; remove the real suffix instead.
        final_res_key = _remove_direction_suffix(final_res_key)
        final_pres[final_res_key] = {
            Const.ORIGIN_DATA: key,
            Const.SCOPE: construct[key],
            Const.STACK: '.'.join(reversed_list) if reversed_list else None,
        }
    return final_pres
ms_construct = struct_json_get(input_param, Const.MS_FRAMEWORK) summary_compare, md5_compare = task_dumppath_get(input_param) check_configuration_param(stack_mode, auto_analyze, fuzzy_match, input_param.get('is_print_compare_log', True)) create_directory(output_path) @@ -213,6 +217,8 @@ def ms_compare(input_param, output_path, **kwargs): except (CompareException, FileCheckException) as error: logger.error('Compare failed. Please check the arguments and do it again!') raise CompareException(error.code) from error + pt_mapping_result = modify_mapping_with_stack(pt_stack, pt_construct) + ms_mapping_result = modify_mapping_with_stack(ms_stack, ms_construct) ms_comparator = MSComparator(cell_mapping, api_mapping) ms_comparator.compare_core(input_param, output_path, stack_mode=stack_mode, auto_analyze=auto_analyze, fuzzy_match=fuzzy_match, summary_compare=summary_compare, -- Gitee