diff --git a/.gitignore b/.gitignore
index c70c40e0f527c8c20a6bf994bcb8070b95e13e27..2f15a00811101c8743f981fecb6976c7066fb941 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,6 +2,7 @@
 __pycache__/
 *.py[cod]
 *$py.class
+.idea
 
 # C extensions
 *.so
diff --git a/debug/accuracy_tools/msprobe/core/common/const.py b/debug/accuracy_tools/msprobe/core/common/const.py
index 7eeab086e8bbcd5ba9714a9c52b1229f1fed00f3..ab7fb95ec080d905f94dd53aef87b8dadde18195 100644
--- a/debug/accuracy_tools/msprobe/core/common/const.py
+++ b/debug/accuracy_tools/msprobe/core/common/const.py
@@ -102,8 +102,24 @@ class Const:
     CPU_LOWERCASE = 'cpu'
     CUDA_LOWERCASE = 'cuda'
     DISTRIBUTED = 'Distributed'
+
+    # struct json param
+    ORIGIN_DATA = "origin_data"
+    SCOPE = "scope"
+    STACK = "stack"
+    ATEN = "Aten"
+    FUNC_SKIP_LIST = ["construct", "__call__"]
+
+    FILE_SKIP_LIST = ["site-packages/mindspore", "package/mindspore", "msprobe", "site-packages/torch", "package/torch"]
+
+    STACK_FILE_INDEX = 0
+
+    STACK_FUNC_INDEX = 1
+
+    CONSTRUCT_NAME_INDEX = -3
+
     INPLACE_LIST = [
         "broadcast", "all_reduce", "reduce", "all_gather", "gather", "scatter", "reduce_scatter",
         "_reduce_scatter_base", "_all_gather_base", "send", "recv", "irecv", "isend", "all_to_all_single", "all_to_all",
diff --git a/debug/accuracy_tools/msprobe/core/common/utils.py b/debug/accuracy_tools/msprobe/core/common/utils.py
index 9efc7b7a84857417dbd985aa7f5f2c8281f04772..21411bde464759902683f496a26a0d417d1f7772 100644
--- a/debug/accuracy_tools/msprobe/core/common/utils.py
+++ b/debug/accuracy_tools/msprobe/core/common/utils.py
@@ -22,7 +22,7 @@ import time
 import json
 from datetime import datetime, timezone
 
-from msprobe.core.common.file_utils import (FileOpen, check_file_or_directory_path)
+from msprobe.core.common.file_utils import (FileOpen, check_file_or_directory_path, load_json)
 from msprobe.core.common.const import Const, CompareConst
 from msprobe.core.common.log import logger
 
@@ -342,6 +342,35 @@ def md5_find(data):
     return False
 
 
+def struct_json_get(input_param, framework):
+    if framework == Const.PT_FRAMEWORK:
+        prefix = "bench"
+    elif framework == Const.MS_FRAMEWORK:
+        prefix = "npu"
+    else:
+        logger.error(f"Unsupported framework: {framework}.")
+        raise CompareException(CompareException.INVALID_PARAM_ERROR)
+
+    frame_json_path = input_param.get(f"{prefix}_json_path", None)
+    if not frame_json_path:
+        logger.error(f"Please check that the {prefix} json path is valid.")
+        raise CompareException(CompareException.INVALID_PATH_ERROR)
+    stack_json = os.path.join(frame_json_path, "stack.json")
+    construct_json = os.path.join(frame_json_path, "construct.json")
+
+    if not os.path.exists(stack_json) and not os.path.exists(construct_json):
+        logger.info("stack.json and construct.json not found.")
+        return {}, {}
+    if not os.path.exists(stack_json) or not os.path.exists(construct_json):
+        logger.error("Either stack.json or construct.json is missing, please check the dump directory.")
+        raise CompareException(CompareException.INVALID_PATH_ERROR)
+    input_param[f"{prefix}_json_path"] = os.path.join(frame_json_path, "dump.json")
+
+    stack = load_json(stack_json)
+    construct = load_json(construct_json)
+    return stack, construct
+
+
 def task_dumppath_get(input_param):
     npu_path = input_param.get("npu_json_path", None)
     bench_path = input_param.get("bench_json_path", None)
diff --git a/debug/accuracy_tools/msprobe/core/compare/acc_compare.py b/debug/accuracy_tools/msprobe/core/compare/acc_compare.py
index edf59a813865851101f5b1980b2cf02f9cb913f9..5032f314437c72c8c4bd8425890d09fbdb8ff498 100644
--- a/debug/accuracy_tools/msprobe/core/compare/acc_compare.py
+++ 
b/debug/accuracy_tools/msprobe/core/compare/acc_compare.py @@ -177,7 +177,137 @@ class Comparator: result_df = self.make_result_table(result, md5_compare, summary_compare, stack_mode) return result_df - + + def merge_data(self, json_data, stack_json_data, summary_compare, md5_compare): + ops_all = {} + for op_name in json_data.get('data', {}): + merge_list = self.gen_merge_list(json_data, op_name, stack_json_data, summary_compare, + md5_compare) + if merge_list: + input_index, kwarg_index, output_index = 0, 0, 0 + for index, input_or_output in enumerate(merge_list['op_name']): + data_name = merge_list.get('data_name') + data_name = data_name[index] if data_name else None + if Const.INPUT in input_or_output: + ops_all[input_or_output] = {'struct': merge_list.get('input_struct')[input_index], + 'summary': merge_list.get('summary')[index], + 'data_name': data_name, + 'stack_info': merge_list.get('stack_info')} + input_index += 1 + + elif 'kwarg' in input_or_output: + ops_all[input_or_output] = {'struct': merge_list.get('kwargs_struct')[kwarg_index], + 'summary': merge_list.get('summary')[index], + 'data_name': data_name, + 'stack_info': merge_list.get('stack_info')} + kwarg_index += 1 + + elif Const.OUTPUT in input_or_output: + ops_all[input_or_output] = {'struct': merge_list.get('output_struct')[output_index], + 'summary': merge_list.get('summary')[index], + 'data_name': data_name, + 'stack_info': merge_list.get('stack_info')} + output_index += 1 + return ops_all + + def get_accuracy(self, npu_ops_all, bench_ops_all, summary_compare, md5_compare): + result = [] + for ms_op_name, bench_op_name in self.data_mapping_dict.items(): + if ms_op_name in npu_ops_all and bench_op_name in bench_ops_all: + + err_msg = "" + npu_stack_info = npu_ops_all.get(ms_op_name).get("stack_info", None) + bench_stack_info = bench_ops_all.get(bench_op_name).get("stack_info", None) + has_stack = npu_stack_info and bench_stack_info + + if md5_compare: + result_item = [ms_op_name, bench_op_name, npu_ops_all.get(ms_op_name).get('struct')[0], + bench_ops_all.get(bench_op_name).get('struct')[0], + npu_ops_all.get(ms_op_name).get('struct')[1], + bench_ops_all.get(bench_op_name).get('struct')[1], + npu_ops_all.get(ms_op_name).get('struct')[2], + bench_ops_all.get(bench_op_name).get('struct')[2], + CompareConst.PASS if npu_ops_all.get(ms_op_name).get('struct')[2] + == bench_ops_all.get(bench_op_name).get('struct')[2] + else CompareConst.DIFF] + if has_stack: + result_item.extend(npu_stack_info) + else: + result_item.append(CompareConst.NONE) + result.append(result_item) + continue + + if summary_compare: + result_item = [ms_op_name, bench_op_name, npu_ops_all.get(ms_op_name).get('struct')[0], + bench_ops_all.get(bench_op_name).get('struct')[0], + npu_ops_all.get(ms_op_name).get('struct')[1], + bench_ops_all.get(bench_op_name).get('struct')[1], + " ", " ", " ", " ", " ", " ", " ", " "] + else: + result_item = [ms_op_name, bench_op_name, npu_ops_all.get(ms_op_name).get('struct')[0], + bench_ops_all.get(bench_op_name).get('struct')[0], + npu_ops_all.get(ms_op_name).get('struct')[1], + bench_ops_all.get(bench_op_name).get('struct')[1], + " ", " ", " ", " ", " "] + + npu_summary_data = npu_ops_all.get(ms_op_name).get("summary") + result_item.extend(npu_summary_data) + bench_summary_data = bench_ops_all.get(bench_op_name).get("summary") + result_item.extend(bench_summary_data) + + if summary_compare: + start_idx = CompareConst.SUMMARY_COMPARE_RESULT_HEADER.index(CompareConst.MAX_DIFF) + warning_flag = False + for i, (npu_val, 
bench_val) in enumerate(zip(npu_summary_data, bench_summary_data)):
+                        if isinstance(npu_val, (float, int)) and isinstance(bench_val, (float, int)):
+                            diff = npu_val - bench_val
+                            if bench_val != 0:
+                                relative = str(abs((diff / bench_val) * 100)) + '%'
+                            else:
+                                relative = "N/A"
+                            result_item[start_idx + i] = diff
+                            result_item[start_idx + i + 4] = relative
+                            magnitude_diff = abs(diff) / (max(abs(npu_val), abs(bench_val)) + 1e-10)
+                            if magnitude_diff > 0.5:
+                                warning_flag = True
+                        else:
+                            result_item[start_idx + i] = CompareConst.NONE
+                    accuracy_check = CompareConst.WARNING if warning_flag else ""
+                    err_msg += "Need double check api accuracy." if warning_flag else ""
+                    for i in range(start_idx, len(result_item)):
+                        if str(result_item[i]) in ('inf', '-inf', 'nan'):
+                            result_item[i] = f'{result_item[i]}\t'
+                result_item.append(accuracy_check if summary_compare else CompareConst.ACCURACY_CHECK_YES)
+                result_item.append(err_msg)
+                if has_stack:
+                    result_item.extend(npu_stack_info)
+                else:
+                    result_item.append(CompareConst.NONE)
+
+                if not (summary_compare or md5_compare):
+                    result_item.append(npu_ops_all.get(ms_op_name).get("data_name", None))
+
+                result.append(result_item)
+            elif ms_op_name not in npu_ops_all:
+                logger.warning(f'Cannot find mindspore op name: `{ms_op_name}` in mindspore dump json file.')
+            elif bench_op_name not in bench_ops_all:
+                logger.warning(f'Cannot find bench op name: `{bench_op_name}` in bench dump json file.')
+
+        return result
+
+    def compare_process_custom(self, file_handles, stack_mode, summary_compare=False, md5_compare=False):
+        npu_json_handle, bench_json_handle, stack_json_handle = file_handles
+        npu_json_data = json.load(npu_json_handle)
+        bench_json_data = json.load(bench_json_handle)
+        stack_json_data = json.load(stack_json_handle)
+
+        npu_ops_all = self.merge_data(npu_json_data, stack_json_data, summary_compare, md5_compare)
+        bench_ops_all = self.merge_data(bench_json_data, stack_json_data, summary_compare, md5_compare)
+
+        result = self.get_accuracy(npu_ops_all, bench_ops_all, summary_compare, md5_compare)
+        result_df = self.make_result_table(result, md5_compare, summary_compare, stack_mode)
+        return result_df
+
     def compare_by_op(self, npu_op_name, bench_op_name, op_name_mapping_dict, input_param):
         npu_bench_name_list = op_name_mapping_dict[npu_op_name]
         data_name = npu_bench_name_list[1]
@@ -255,8 +385,16 @@ class Comparator:
         with FileOpen(input_parma.get("npu_json_path"), "r") as npu_json, \
                 FileOpen(input_parma.get("bench_json_path"), "r") as bench_json, \
                 FileOpen(input_parma.get("stack_json_path"), "r") as stack_json:
-            result_df = self.compare_process([npu_json, bench_json, stack_json], stack_mode, fuzzy_match,
-                                             summary_compare, md5_compare)
+            if self.data_mapping:
+                result_df = self.compare_process_custom([npu_json, bench_json, stack_json], stack_mode,
+                                                        summary_compare, md5_compare)
+            else:
+                result_df = self.compare_process([npu_json, bench_json, stack_json], stack_mode, fuzzy_match,
+                                                 summary_compare, md5_compare)
+
+        if not result_df.values.tolist():
+            logger.warning("Cannot match any op.")
+            return
 
         if not md5_compare and not summary_compare:
             result_df = self._do_multi_process(input_parma, result_df)
diff --git a/debug/accuracy_tools/msprobe/core/compare/compare_cli.py b/debug/accuracy_tools/msprobe/core/compare/compare_cli.py
index a20a17abf80db8d77b2d74a2bae074e2ab3f40f1..c042fdf8897df1a37806bf6824b78d220043b71e 100644
--- a/debug/accuracy_tools/msprobe/core/compare/compare_cli.py
+++ 
b/debug/accuracy_tools/msprobe/core/compare/compare_cli.py @@ -18,13 +18,16 @@ def compare_cli(args): else: from msprobe.mindspore.compare.ms_compare import ms_compare from msprobe.mindspore.compare.distributed_compare import ms_compare_distributed, ms_graph_compare - if check_file_type(npu_path) == FileCheckConst.FILE and check_file_type(bench_path) == FileCheckConst.FILE: + if (check_file_type(npu_path) == FileCheckConst.FILE and check_file_type(bench_path) == FileCheckConst.FILE) or args.layer_mapping: input_param["npu_json_path"] = input_param.pop("npu_path") input_param["bench_json_path"] = input_param.pop("bench_path") input_param["stack_json_path"] = input_param.pop("stack_path") if frame_name == Const.PT_FRAMEWORK: + kwargs = { + "data_mapping": args.data_mapping + } compare(input_param, args.output_path, stack_mode=args.stack_mode, auto_analyze=auto_analyze, - fuzzy_match=args.fuzzy_match) + fuzzy_match=args.fuzzy_match, **kwargs) else: kwargs = { "stack_mode": args.stack_mode, @@ -32,6 +35,8 @@ def compare_cli(args): "fuzzy_match": args.fuzzy_match, "cell_mapping": args.cell_mapping, "api_mapping": args.api_mapping, + "data_mapping": args.data_mapping, + "layer_mapping": args.layer_mapping } ms_compare(input_param, args.output_path, **kwargs) diff --git a/debug/accuracy_tools/msprobe/core/compare/utils.py b/debug/accuracy_tools/msprobe/core/compare/utils.py index b7fef54ec2a93518fce53ca0f1678030b56f01e1..ba03bcae6dc076bf9c9269593573c27f2506fffc 100644 --- a/debug/accuracy_tools/msprobe/core/compare/utils.py +++ b/debug/accuracy_tools/msprobe/core/compare/utils.py @@ -405,7 +405,7 @@ def merge_tensor(tensor_list, summary_compare, md5_compare): data_name = op_dict["data_name"][-1].rsplit(Const.SEP, 1)[0] if data_name != "-1": op_dict["op_name"][-1] = data_name - + if not op_dict["kwargs_struct"]: del op_dict["kwargs_struct"] @@ -426,9 +426,8 @@ def _compare_parser(parser): parser.add_argument("-cm", "--cell_mapping", dest="cell_mapping", type=str, nargs='?', const=True, help=" The cell mapping file path.", required=False) parser.add_argument("-am", "--api_mapping", dest="api_mapping", type=str, nargs='?', const=True, - help=" The api mapping file path.", required=False) - - - - - + help=" The api mapping file path.", required=False) + parser.add_argument("-dm", "--data_mapping", dest="data_mapping", type=str, + help=" The data mapping file path.", required=False) + parser.add_argument("-lm", "--layer_mapping", dest="layer_mapping", type=str, + help=" The layer mapping file path.", required=False) diff --git a/debug/accuracy_tools/msprobe/mindspore/compare/layer_mapping.py b/debug/accuracy_tools/msprobe/mindspore/compare/layer_mapping.py new file mode 100644 index 0000000000000000000000000000000000000000..dc098de7ee18e163485e96e4ea896ad7e4b1f46b --- /dev/null +++ b/debug/accuracy_tools/msprobe/mindspore/compare/layer_mapping.py @@ -0,0 +1,146 @@ +import re + + +class Trie: + def __init__(self, type_name=None, has_data=False): + self.type_name = type_name + self.call_count_list = [] + self.children = {} + self.has_data = has_data + self.type = None + + def __repr__(self): + return (f"Node(type_name={self.type_name}, " + f"has_data={self.has_data}, call number={len(self.call_count_list)})") + + def insert(self, word, word_type="func"): + parts = word.split('.') + """ + xxx, node_name, type_name, execute_num + etc: Cell.network_with_loss.language_model.encoder.layers.1.attention.out_proj.RowParallelLinear.1 + prefix_name_list: 
Cell.network_with_loss.language_model.encoder.layers.1.attention + node_name: out_proj + type_name: RowParallelLinear + call_count: 1 + """ + type_name = parts[-2] + call_count = parts[-1] + node = self + prefix_name_list = parts[:-2] + + for name in prefix_name_list: + if name not in node.children: + node.children[name] = Trie() + node = node.children[name] + if node.type_name is None: + node.type_name = name + + node.type_name = type_name + node.has_data = True + node.call_count_list.append(call_count) + node.type = word_type + + +def dfs_traverse_and_collect_names(node, path="", result=None): + if result is None: + result = [] + + if node.has_data: + for count in node.call_count_list: + result.append(f"{path}.{node.type_name}.{count}") + + for child_name, child_node in node.children.items(): + new_path = f"{path}.{child_name}" if path else child_name + dfs_traverse_and_collect_names(child_node, new_path, result) + + return result + + +def dfs_traverse_and_collect_converted_names(node, mapping, path="", mapping_path="", result=None): + if result is None: + result = {} + if node is None: + return result + + type_name = node.type_name + if node.has_data: + for count in node.call_count_list: + origin_name = f"{path}.{count}" if node.type == "Cell" else f"{path}.{type_name}.{count}" + mapping_name = f"{mapping_path}.{count}" if node.type == "Cell" else f"{mapping_path}.{type_name}.{count}" + result[origin_name] = mapping_name + + name_mapping = mapping.get(type_name, {}) + + for child_name, child_node in node.children.items(): + new_path = f"{path}.{child_name}" if path else child_name + converted_name = name_mapping.get(child_name, child_name) + new_mapping_path = f"{mapping_path}.{converted_name}" if mapping_path else converted_name + dfs_traverse_and_collect_converted_names( + child_node, mapping, new_path, new_mapping_path, result) + + return result + + +def get_mapping_list(ms_tree, mapping): + ms_pt_mapping = dfs_traverse_and_collect_converted_names(ms_tree, mapping=mapping) + mapping_list = [] + for ms_name, pt_name in ms_pt_mapping.items(): + pt_name = re.sub(r"^Cell", "Module", pt_name) + mapping_list.append((ms_name, pt_name)) + return mapping_list + + +def get_prefix_mapping(scope_list): + """layer name to layer name.class_name""" + layer_mapping = {} + for name, v in scope_list.items(): + origin_data = v.get("origin_data") + if not origin_data.startswith(("Cell", "Module")): + continue + name_list = name.split('.') + prefix_name_list = name_list[:-2] + [name_list[-1]] + prefix_name = '.'.join(prefix_name_list) + layer_mapping[prefix_name] = name + return layer_mapping + + +def get_layer_mapping(ms_scope_list, pt_scope_list, mapping): + + # 1. get layer prefix to full name mapping + # ect: Cell.network_with_loss.language_model.embedding.3 : Cell.network_with_loss.language_model.embedding.Embedding.3 + ms_prefix2fullname = get_prefix_mapping(ms_scope_list) + + # 2. build trie tree + ms_tree = Trie(type_name="Cell") + for k, r in ms_scope_list.items(): + origin_data_name = r.get('origin_data') + data_type = origin_data_name.split('.')[0] + ms_tree.insert(k, data_type) + msname2ptname = get_mapping_list(ms_tree, mapping) + + # 3. 
get pt layer prefix to full name mapping
+    # e.g.: Module.network_with_loss.language_model.embedding.3 : Module.network_with_loss.language_model.embedding.Embedding.3
+    pt_prefix2fullname = get_prefix_mapping(pt_scope_list)
+
+    final_mapping = []
+    for ms_name, pt_name in msname2ptname:
+        final_ms_name = ms_name
+        final_pt_name = pt_name
+        # cell
+        if ms_name in ms_prefix2fullname:
+            final_ms_name = ms_prefix2fullname.get(ms_name)
+            final_pt_name = pt_prefix2fullname.get(pt_name, None)
+        # func
+        elif final_ms_name in ms_scope_list:
+            final_ms_name = ms_scope_list.get(ms_name)['origin_data']
+            # remove forward/backward
+            final_ms_name = '.'.join(final_ms_name.split('.')[:-1])
+            final_pt_name = pt_scope_list.get(pt_name, None)
+            if final_pt_name:
+                final_pt_name = final_pt_name['origin_data']
+                final_pt_name = '.'.join(final_pt_name.split('.')[:-1])
+            else:
+                continue
+        final_mapping.append((final_ms_name, final_pt_name))
+
+    return final_mapping
diff --git a/debug/accuracy_tools/msprobe/mindspore/compare/modify_mapping.py b/debug/accuracy_tools/msprobe/mindspore/compare/modify_mapping.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ba70c3b5c4f820efce59cec9d4397e197f0addf
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/mindspore/compare/modify_mapping.py
@@ -0,0 +1,107 @@
+from msprobe.core.common.const import Const
+from msprobe.core.common.log import logger
+
+def find_regard_scope(lines, start_sign, end_sign):
+    # locate the start and end positions of the relevant code-stack lines
+    start_pos = end_pos = -1
+    for idx, ii in enumerate(lines):
+        if start_sign in ii:
+            start_pos = idx
+        elif end_sign in ii:
+            end_pos = idx
+    return start_pos, end_pos
+
+
+def find_stack_func_list(lines):
+    res_list = []
+    # filter the in-scope stack lines and collect the function names
+    for line in lines:
+        ele_list = line.split(',')
+        file_ele = ele_list[Const.STACK_FILE_INDEX]
+        if any(ii in file_ele for ii in Const.FILE_SKIP_LIST):
+            continue
+
+        func_ele = ele_list[Const.STACK_FUNC_INDEX]
+        if any(ii in func_ele for ii in Const.FUNC_SKIP_LIST):
+            continue
+
+        in_func_name = func_ele.split()[Const.STACK_FUNC_INDEX]
+
+        res_list.append(in_func_name)
+    # reverse res_list so that the outermost caller comes first
+    reversed_list = res_list[::-1]
+    return reversed_list
+
+
+def get_duplicated_name(components):
+    duplicated_components = components
+    if len(components) < 3 or components[Const.CONSTRUCT_NAME_INDEX].isdigit():
+        logger.warning("Key in construct.json has fewer than 3 parts or its name is not valid.")
+    else:
+        # duplicate the node name, e.g. Functional.add.add.X ward
+        duplicated_components = components[:Const.CONSTRUCT_NAME_INDEX + 1] + components[Const.CONSTRUCT_NAME_INDEX:]
+    return duplicated_components
+
+
+def modify_mapping_with_stack(stack, construct):
+    if not stack or not construct:
+        return {}
+
+    # whether the data comes from mindspore
+    is_ms = any("Cell" in ii for ii in construct)
+    # the adjusted mapping structure
+    final_pres = {}
+    # resolve the ownership of every node
+    for key in construct:
+        key_components = key.split('.')
+        # if the name does not start with a standard prefix, convert it to one
+        if not key.startswith(("Module", "Cell")):
+            code_list = stack.get(key, None)
+            if not code_list:
+                logger.info(f"Key `{key}` not found in code stack.")
+                continue
+
+            # if no parent scope name is found, default the top-level scope to Module or Cell
+            if not construct.get(key, None):
+                if not key.startswith(("Module", "Cell")):
+                    # convert the node name to the standard Module or Cell prefix
+                    key_components[0] = "Cell" if is_ms else "Module"
+                    # duplicate the node name as its type, e.g. add.add with add at index -3
+                    duplicated_components = get_duplicated_name(key_components)
+                    modified_key = '.'.join(duplicated_components)
+                else:
+                    modified_key = key
+                modified_key = modified_key.replace(".forward", "").replace(".backward", "")
+                final_pres[modified_key] = {Const.ORIGIN_DATA: key, Const.SCOPE: None, Const.STACK: None}
+                continue
+
+            parent = construct[key].split('.')
+            if len(parent) < 4:
+                logger.warning("Parent name in construct.json is not valid.")
+                continue
+            # {name}.Class.count_number.X ward Or {name}.Class.count_number.X ward.ele_number
+            parent_idx = -4 if not parent[-4].isdigit() else -5
+
+            parent_name = parent[parent_idx]
+            if parent_name.endswith('s'):
+                parent_name = parent_name[:-1]
+            # {name}.count_number.X ward
+            func_name = key_components[-3]
+
+            start_pos, end_pos = find_regard_scope(code_list, func_name, parent_name)
+
+            # take the code lines between the two positions
+            regard_scope = code_list[start_pos:end_pos]
+
+            func_stack_list = find_stack_func_list(regard_scope)
+
+            # combine the parent prefix (up to the node name), the call-stack names and the duplicated node name of the original key
+            final_res_key = '.'.join(parent[:parent_idx + 1] + func_stack_list +
+                                     key_components[1:Const.CONSTRUCT_NAME_INDEX + 1] + key_components[Const.CONSTRUCT_NAME_INDEX:])
+            final_res_key = final_res_key.replace(".forward", "").replace(".backward", "")
+        else:
+            final_res_key = '.'.join(key_components[:-2] + [key_components[-1]])
+            func_stack_list = []
+        final_pres[final_res_key] = {Const.ORIGIN_DATA: key, Const.SCOPE: construct[key],
+                                     Const.STACK: '.'.join(func_stack_list) if func_stack_list else None}
+    return final_pres
diff --git a/debug/accuracy_tools/msprobe/mindspore/compare/ms_compare.py b/debug/accuracy_tools/msprobe/mindspore/compare/ms_compare.py
index 24d3d0574ffdafddba859744944a50dd92d7d222..40f3952913c2871dcc78e1dd5e0178be3db6a618 100644
--- a/debug/accuracy_tools/msprobe/mindspore/compare/ms_compare.py
+++ b/debug/accuracy_tools/msprobe/mindspore/compare/ms_compare.py
@@ -1,26 +1,38 @@
 import os
 import copy
+from itertools import zip_longest
+
 from msprobe.core.common.utils import check_compare_param, CompareException, check_configuration_param, \
-    task_dumppath_get
-from msprobe.core.common.file_utils import create_directory, load_yaml, load_npy
+    task_dumppath_get, struct_json_get
+from msprobe.core.common.file_utils import create_directory, load_yaml, load_npy, load_json
 from msprobe.core.common.const import Const, CompareConst
 from msprobe.core.common.log import logger
 from msprobe.core.common.exceptions import FileCheckException
 from msprobe.core.compare.acc_compare import Comparator
 from msprobe.core.compare.check import check_struct_match, fuzzy_check_op
-
+from msprobe.mindspore.compare.modify_mapping import modify_mapping_with_stack
+from msprobe.mindspore.compare.layer_mapping import get_layer_mapping
 
 class MSComparator(Comparator):
-    def __init__(self, cell_mapping=None, api_mapping=None):
+    def __init__(self, cell_mapping=None, api_mapping=None, data_mapping=None):
         self.frame_name = MSComparator.__name__
         self.cell_mapping = cell_mapping
         self.api_mapping = api_mapping
-        self.cross_frame = cell_mapping is not None or api_mapping is not None
+        self.data_mapping = data_mapping
+        self.cross_frame = cell_mapping is not None or api_mapping is not None or data_mapping is not None
         self.cell_mapping_dict = self.load_mapping_file(self.cell_mapping)
         self.api_mapping_dict = self.load_mapping_file(self.api_mapping)
         if api_mapping is not None:
             self.ms_to_pt_mapping = self.load_internal_api()
-
+
+        if isinstance(self.data_mapping, str) or self.data_mapping is None:
+            self.data_mapping_dict = self.load_mapping_file(self.data_mapping)
+        elif isinstance(self.data_mapping, dict):
+            self.data_mapping_dict = self.data_mapping
+        else:
+            raise TypeError(f"The type of parameter `data_mapping` must be 
dict, str or None, but got " + f"{type(self.data_mapping)}") + def load_internal_api(self): cur_path = os.path.dirname(os.path.realpath(__file__)) yaml_path = os.path.join(cur_path,"ms_to_pt_api.yaml") @@ -72,7 +84,7 @@ class MSComparator(Comparator): if load_pt_file: import torch from msprobe.pytorch.common.utils import load_pt - data_value = load_pt(data_path).detach() + data_value = load_pt(data_path, True).detach() if data_value.dtype == torch.bfloat16: data_value = data_value.to(torch.float32) data_value = data_value.numpy() @@ -99,7 +111,7 @@ class MSComparator(Comparator): elif self.ms_to_pt_mapping.get(ms_api_name) == pt_api_name: return self.api_replace(npu_op_name, ms_api_name, pt_api_name) else: - return npu_op_name + return npu_op_name def remove_element(self, op_name, struct, summary, idx): del op_name[idx] @@ -198,7 +210,84 @@ class MSComparator(Comparator): del_bench_dict.update({CompareConst.OP_NAME: bench_op_name, CompareConst.INPUT_STRUCT: bench_struct_in, CompareConst.OUTPUT_STRUCT: bench_struct_out, CompareConst.SUMMARY: bench_summary}) return del_bench_dict - + +def sort_by_execution_sequence(npu_data, bench_data, mapping_list, flag): + def generate_execution_sequence(data): + sequence_map = {} + for index, item in enumerate(data.keys()): + if flag in item: + item_split = item.split(".") + item_name = ".".join(item_split[0:-2]) + item_index = item_split[-1] + if item_index == 'forward' or item_index == 'backward': + item_index = item_split[-2] + item_key = f"{item_name}.{item_index}" + sequence_map[item_key] = index + return sequence_map + + npu_map = generate_execution_sequence(npu_data) + bench_map = generate_execution_sequence(bench_data) + + def sort_by_map(item): + first_key = npu_map.get(item[0], 0xffff) + second_key = bench_map.get(item[1], 0xffff) + return first_key, second_key + + return sorted(mapping_list, key=sort_by_map) + + +def generate_kernel_data(map_value, data, flag): + if not map_value: + return [], [] + inputs_name = [] + outputs_name = [] + map_split = map_value.split(".") + map_name = ".".join(map_split[0:-1]) + map_index = map_split[-1] + for key, value in data.items(): + if key.find(flag) != -1 and key.find(map_name) != -1: + if key.split(".")[-1] != map_index and key.split(".")[-2] != map_index : + continue + input_args = value.get('input_args', {}) + output_args = value.get('output', {}) + for i in range(len(input_args)): + inputs_name.append(f"{key}.input.{i}") + for i in range(len(output_args)): + outputs_name.append(f"{key}.output.{i}") + return inputs_name, outputs_name + + +def generate_file_mapping(npu_json_path, bench_json_path, mapping_list): + + npu_data = load_json(npu_json_path).get("data", {}) + bench_data = load_json(bench_json_path).get("data", {}) + + forward_data = [] + mapping_list = sort_by_execution_sequence(npu_data, bench_data, mapping_list, "forward") + for map_value in mapping_list: + npu_forward_inputs, npu_backward_outputs = generate_kernel_data(map_value[0], npu_data, "forward") + bench_forward_inputs, bench_backward_outputs = generate_kernel_data(map_value[1], bench_data, "forward") + inputs_zip = list(zip_longest(npu_forward_inputs, bench_forward_inputs)) + outputs_zip = list(zip_longest(npu_backward_outputs, bench_backward_outputs)) + forward_data.extend(inputs_zip) + forward_data.extend(outputs_zip) + + backward_data = [] + mapping_list = sort_by_execution_sequence(npu_data, bench_data, mapping_list, "backward") + for map_value in mapping_list: + npu_forward_inputs, npu_backward_outputs = 
generate_kernel_data(map_value[0], npu_data, "backward") + bench_forward_inputs, bench_backward_outputs = generate_kernel_data(map_value[1], bench_data, "backward") + inputs_zip = list(zip_longest(npu_forward_inputs, bench_forward_inputs)) + outputs_zip = list(zip_longest(npu_backward_outputs, bench_backward_outputs)) + backward_data.extend(inputs_zip) + backward_data.extend(outputs_zip) + + kernel_data = forward_data + backward_data + result = {key: value for key, value in kernel_data if key is not None} + + return result + + def ms_compare(input_param, output_path, **kwargs): try: stack_mode = kwargs.get('stack_mode', False) @@ -206,6 +295,11 @@ def ms_compare(input_param, output_path, **kwargs): fuzzy_match = kwargs.get('fuzzy_match', False) cell_mapping = kwargs.get('cell_mapping', None) api_mapping = kwargs.get('api_mapping', None) + data_mapping = kwargs.get('data_mapping', None) + layer_mapping = kwargs.get('layer_mapping', None) + + pt_stack, pt_construct = struct_json_get(input_param, Const.PT_FRAMEWORK) + ms_stack, ms_construct = struct_json_get(input_param, Const.MS_FRAMEWORK) summary_compare, md5_compare = task_dumppath_get(input_param) check_configuration_param(stack_mode, auto_analyze, fuzzy_match, input_param.get('is_print_compare_log', True)) create_directory(output_path) @@ -213,7 +307,13 @@ def ms_compare(input_param, output_path, **kwargs): except (CompareException, FileCheckException) as error: logger.error('Compare failed. Please check the arguments and do it again!') raise CompareException(error.code) from error - ms_comparator = MSComparator(cell_mapping, api_mapping) + if layer_mapping: + mapping = load_yaml(layer_mapping) + ms_mapping_result = modify_mapping_with_stack(ms_stack, ms_construct) + pt_mapping_result = modify_mapping_with_stack(pt_stack, pt_construct) + layer_mapping = get_layer_mapping(ms_mapping_result, pt_mapping_result, mapping) + data_mapping = generate_file_mapping(input_param.get("npu_json_path"), input_param.get("bench_json_path"), layer_mapping) + ms_comparator = MSComparator(cell_mapping, api_mapping, data_mapping) ms_comparator.compare_core(input_param, output_path, stack_mode=stack_mode, auto_analyze=auto_analyze, fuzzy_match=fuzzy_match, summary_compare=summary_compare, md5_compare=md5_compare) diff --git a/debug/accuracy_tools/msprobe/pytorch/compare/pt_compare.py b/debug/accuracy_tools/msprobe/pytorch/compare/pt_compare.py index d1fd414a05c00dc310b99d2200c3336e00f2fed7..5d89d829e8c3fddb2f40bd11f24fdaf56de635df 100644 --- a/debug/accuracy_tools/msprobe/pytorch/compare/pt_compare.py +++ b/debug/accuracy_tools/msprobe/pytorch/compare/pt_compare.py @@ -5,13 +5,28 @@ from msprobe.pytorch.common.log import logger from msprobe.core.common.exceptions import FileCheckException from msprobe.core.compare.acc_compare import Comparator from msprobe.core.common.utils import check_configuration_param, task_dumppath_get, check_compare_param, CompareException -from msprobe.core.common.file_utils import FileChecker, create_directory +from msprobe.core.common.file_utils import FileChecker, create_directory, load_yaml from msprobe.pytorch.common.utils import load_pt class PTComparator (Comparator): - def __init__(self): + def __init__(self, data_mapping=None): self.frame_name = PTComparator.__name__ + self.data_mapping = data_mapping + if isinstance(self.data_mapping, str) or self.data_mapping is None: + self.data_mapping_dict = self.load_mapping_file(self.data_mapping) + elif isinstance(self.data_mapping, dict): + self.data_mapping_dict = self.data_mapping 
+        else:
+            raise TypeError(f"The type of parameter `data_mapping` must be dict, str or None, but got "
+                            f"{type(self.data_mapping)}")
+
+    def load_mapping_file(self, mapping_file):
+        if isinstance(mapping_file, str):
+            mapping_dict = load_yaml(mapping_file)
+        else:
+            mapping_dict = {}
+        return mapping_dict
 
     def read_npy_data(self, dir_path, file_name):
         data_path = os.path.join(dir_path, file_name)
@@ -35,16 +50,17 @@ class PTComparator (Comparator):
         return data_value
 
 
-def compare(input_param, output_path, stack_mode=False, auto_analyze=True, fuzzy_match=False):
+def compare(input_param, output_path, stack_mode=False, auto_analyze=True, fuzzy_match=False, **kwargs):
     try:
         summary_compare, md5_compare = task_dumppath_get(input_param)
         check_configuration_param(stack_mode, auto_analyze, fuzzy_match, input_param.get('is_print_compare_log', True))
         create_directory(output_path)
         check_compare_param(input_param, output_path, summary_compare, md5_compare)
+        data_mapping = kwargs.get('data_mapping', None)
     except (CompareException, FileCheckException) as error:
         logger.error('Compare failed. Please check the arguments and do it again!')
         raise CompareException(error.code) from error
-    pt_comparator = PTComparator()
+    pt_comparator = PTComparator(data_mapping)
     pt_comparator.compare_core(input_param, output_path, stack_mode=stack_mode, auto_analyze=auto_analyze,
                                fuzzy_match=fuzzy_match, summary_compare=summary_compare,
                                md5_compare=md5_compare)
diff --git a/debug/accuracy_tools/msprobe/test/core_ut/common/test_utils.py b/debug/accuracy_tools/msprobe/test/core_ut/common/test_utils.py
index 93756857ada4511b0eb2d25e5763b7fe877c9a53..14e914a10e3705dbd408534ff29c159be621f67f 100644
--- a/debug/accuracy_tools/msprobe/test/core_ut/common/test_utils.py
+++ b/debug/accuracy_tools/msprobe/test/core_ut/common/test_utils.py
@@ -38,7 +38,8 @@ from msprobe.core.common.utils import (CompareException,
                                        check_json_file,
                                        check_regex_prefix_format_valid,
                                        get_dump_data_path,
-                                       task_dumppath_get)
+                                       task_dumppath_get,
+                                       struct_json_get)
 
 from msprobe.core.common.file_utils import (FileCheckConst,
                                             FileCheckException,
@@ -353,3 +354,65 @@ class TestUtils(TestCase):
             f.write("Hello, World!")
         self.assertEqual(get_file_content_bytes('test.txt'), b"Hello, World!")
         os.remove('test.txt')
+
+    def test_struct_json_get_pt_invalid_path_then_fail(self):
+        input_param = {
+            "bench_json_path": None
+        }
+        framework = Const.PT_FRAMEWORK
+        with self.assertRaises(CompareException) as context:
+            stack, construct = struct_json_get(input_param, framework)
+        self.assertEqual(context.exception.code, CompareException.INVALID_PATH_ERROR)
+
+    def test_struct_json_get_ms_invalid_path_then_fail(self):
+        input_param = {
+            "npu_json_path": None
+        }
+        framework = Const.MS_FRAMEWORK
+        with self.assertRaises(CompareException) as context:
+            stack, construct = struct_json_get(input_param, framework)
+        self.assertEqual(context.exception.code, CompareException.INVALID_PATH_ERROR)
+
+    @patch('os.path.exists', return_value=True)
+    @patch('msprobe.core.common.utils.load_json')
+    def test_struct_json_get_pt_framework_valid_paths_then_pass(self, mock_load_json, mock_exists):
+        def load_json_side_effect(path):
+            if path.endswith("stack.json"):
+                return {'stack_key': 'stack_value'}
+            elif path.endswith("construct.json"):
+                return {'construct_key': 'construct_value'}
+            else:
+                raise FileNotFoundError
+        # set the framework to PT_FRAMEWORK
+        input_param = {}
+        framework = Const.PT_FRAMEWORK
+        prefix = 'bench'
+        frame_json_path = '/xx/xxx'
+        input_param[f'{prefix}_json_path'] = frame_json_path
+        mock_load_json.side_effect = load_json_side_effect
+
+        stack, construct = struct_json_get(input_param, framework)
+        self.assertEqual(stack, {'stack_key': 'stack_value'})
+        self.assertEqual(construct, {'construct_key': 'construct_value'})
+
+    @patch('os.path.exists', return_value=True)
+    @patch('msprobe.core.common.utils.load_json')
+    def test_struct_json_get_ms_framework_valid_paths_then_pass(self, mock_load_json, mock_exists):
+        def load_json_side_effect(path):
+            if path.endswith("stack.json"):
+                return {'stack_key': 'stack_value'}
+            elif path.endswith("construct.json"):
+                return {'construct_key': 'construct_value'}
+            else:
+                raise FileNotFoundError
+        # set the framework to MS_FRAMEWORK
+        input_param = {}
+        framework = Const.MS_FRAMEWORK
+        prefix = 'npu'
+        frame_json_path = '/xx/xxx'
+        input_param[f'{prefix}_json_path'] = frame_json_path
+        mock_load_json.side_effect = load_json_side_effect
+
+        stack, construct = struct_json_get(input_param, framework)
+        self.assertEqual(stack, {'stack_key': 'stack_value'})
+        self.assertEqual(construct, {'construct_key': 'construct_value'})
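
A minimal usage sketch of the layer-mapping comparison flow added in this patch. All paths, the output directory and the YAML file name below are illustrative placeholders (not part of the patch); the call goes through the updated ms_compare()/MSComparator interfaces shown above.

    from msprobe.mindspore.compare.ms_compare import ms_compare

    # The npu/bench json paths point at dump directories that contain dump.json,
    # stack.json and construct.json; struct_json_get() rewrites them to the
    # dump.json files before the comparison starts.
    input_param = {
        "npu_json_path": "./ms_dump/step0/rank0",                # placeholder MindSpore dump dir
        "bench_json_path": "./pt_dump/step0/rank0",              # placeholder PyTorch dump dir
        "stack_json_path": "./ms_dump/step0/rank0/stack.json",   # placeholder stack file
        "is_print_compare_log": True,
    }

    # layer_mapping points at a user-provided YAML that maps MindSpore layer names to
    # their PyTorch counterparts; ms_compare() converts it into a per-tensor data
    # mapping via modify_mapping_with_stack / get_layer_mapping / generate_file_mapping
    # and then runs the data-mapping based comparison (compare_process_custom).
    ms_compare(input_param, "./compare_output", layer_mapping="./layer_mapping.yaml")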