diff --git a/debug/accuracy_tools/msprobe/mindspore/cell_processor.py b/debug/accuracy_tools/msprobe/mindspore/cell_processor.py index 04ede42528d253bf023a6614fd44c4de55460df6..8589e86f6ca2568c77f519aecb0e04a95160ec3c 100644 --- a/debug/accuracy_tools/msprobe/mindspore/cell_processor.py +++ b/debug/accuracy_tools/msprobe/mindspore/cell_processor.py @@ -117,11 +117,11 @@ class CellProcessor: cells_and_names_in_graph_mode = [] for index, cells_and_names in cells_with_index_in_graph_mode.items(): model = models if index == "-1" else models[int(index)] - for name, cell in cells_and_names: + for name, cell, parent_cell in cells_and_names: if cell == model: continue cell_index = (index + Const.SEP) if index != "-1" else "" - cells_and_names_in_graph_mode.append((f'{cell_index}{name}', cell)) + cells_and_names_in_graph_mode.append((f'{cell_index}{name}', cell, parent_cell)) if cells_and_names_in_graph_mode: Runtime.run_mode = MsConst.PYNATIVE_GRAPH_MODE diff --git a/debug/accuracy_tools/msprobe/mindspore/common/utils.py b/debug/accuracy_tools/msprobe/mindspore/common/utils.py index bac46a194299326218f58864768477c3ac5fc2c3..3a58c516472066acf73fb58fc7ec4259da288a01 100644 --- a/debug/accuracy_tools/msprobe/mindspore/common/utils.py +++ b/debug/accuracy_tools/msprobe/mindspore/common/utils.py @@ -258,14 +258,14 @@ def is_decorated_by_jit(func): @recursion_depth_decorator('msprobe.mindspore.common.utils.get_cells_and_names') -def get_cells_and_names(model, cells_set=None, name_prefix=''): +def get_cells_and_names(model, cells_set=None, name_prefix='', parent_cell=None): cells_set = cells_set if cells_set else set() if model in cells_set: return cells_set.add(model) jit_decorated = is_decorated_by_jit(model.construct) - yield name_prefix, model, jit_decorated + yield name_prefix, model, jit_decorated, parent_cell if jit_decorated: return @@ -275,9 +275,9 @@ def get_cells_and_names(model, cells_set=None, name_prefix=''): cells_name_prefix = f'{name_prefix}{Const.SEP}{name}' if name_prefix else name jit_decorated = is_decorated_by_jit(model.construct) if jit_decorated: - yield cells_name_prefix, cell, jit_decorated + yield cells_name_prefix, cell, jit_decorated, model else: - for ele in get_cells_and_names(cell, cells_set, cells_name_prefix): + for ele in get_cells_and_names(cell, cells_set, cells_name_prefix, model): yield ele @@ -288,9 +288,9 @@ def get_cells_and_names_with_index(models): def distinguish_cells(cells): cells_in_pynative_mode = [] cells_in_graph_mode = [] - for name, cell, jit_decorated in cells: + for name, cell, jit_decorated, parent_cell in cells: if jit_decorated: - cells_in_graph_mode.append((name, cell)) + cells_in_graph_mode.append((name, cell, parent_cell)) else: cells_in_pynative_mode.append((name, cell)) return cells_in_pynative_mode, cells_in_graph_mode diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/cell_dump_process.py b/debug/accuracy_tools/msprobe/mindspore/dump/cell_dump_process.py index 2e152f4dba3a78ce7c933dd438d832bf7286699e..bc7162af97893c92ee28d2e62cb482177a80bf71 100644 --- a/debug/accuracy_tools/msprobe/mindspore/dump/cell_dump_process.py +++ b/debug/accuracy_tools/msprobe/mindspore/dump/cell_dump_process.py @@ -38,6 +38,8 @@ DEFAULT_RANK_DIR = "rank0" KEY_LAYERS = "layers" construct = {} cell_list = [] +free_cells = {} +parent_cell_types = {} KEY_SIDE_EFFECT = "side_effect_io" KEY_TOPLAYER = "TopLayer" KEY_FORWARD = CoreConst.FORWARD @@ -302,7 +304,24 @@ def check_relation(cell_name, parent_cell_name): return False +def get_parent_cell_name(child_cell_name): + parent_cell_name = '' + + last_dot_index = child_cell_name.rfind(CoreConst.SEP) + if last_dot_index == -1: + return parent_cell_name + + layers_pattern = rf"{CoreConst.SEP}{KEY_LAYERS}{CoreConst.SEP}\d+$" + if re.search(layers_pattern, child_cell_name): + parent_cell_name = re.sub(layers_pattern, '', child_cell_name) + else: + parent_cell_name = child_cell_name[:last_dot_index] + + return parent_cell_name + + def get_construct(cell_list_input): + global free_cells, parent_cell_types for cell in cell_list_input: cell_name = get_cell_name(cell) cell_data_mode = get_data_mode(cell) @@ -316,7 +335,20 @@ def get_construct(cell_list_input): found_flag = True break if not found_flag: - construct.update({cell: None}) + cell_name_with_mode = f'{cell_name}{CoreConst.SEP}{cell_data_mode}' + if cell_name_with_mode in free_cells: + construct.update({cell: free_cells.get(cell_name_with_mode)}) + continue + + parent_cell = None + parent_cell_name = get_parent_cell_name(cell_name) + if parent_cell_name and cell_name in parent_cell_types: + parent_cell = CoreConst.SEP.join([CoreConst.CELL, parent_cell_name, parent_cell_types.get(cell_name)]) + second_last_dot_index = cell.rfind(CoreConst.SEP, 0, cell.rfind(CoreConst.SEP)) + parent_cell = f'{parent_cell}{cell[second_last_dot_index:]}' + free_cells[cell_name_with_mode] = parent_cell + + construct.update({cell: parent_cell}) def generate_construct(path): @@ -470,7 +502,7 @@ def process_csv(path): if col_name in columns: value = convert_special_values(row[col_name]) tensor_json[json_key] = value - + if io_key == KEY_INPUT: data_info.append([op_name, CoreConst.INPUT_ARGS, tensor_json]) elif io_key == KEY_OUTPUT: @@ -794,7 +826,7 @@ def create_kbyk_json(dump_path, summary_mode, step): def start(config: CellDumpConfig): - global dump_task + global dump_task, parent_cell_types dump_task = config.task net = config.net dump_path = config.dump_path @@ -822,7 +854,7 @@ def start(config: CellDumpConfig): return if isinstance(net, nn.Cell): - net = (('', net),) + net = (('', net, None),) td_config_path = "" try: @@ -845,6 +877,7 @@ def start(config: CellDumpConfig): black_list = ["grad_reducer", ""] for name_and_model in net: + parent_cell_types[name_and_model[0]] = name_and_model[2].__class__.__name__ for name, cell in name_and_model[1].cells_and_names(name_prefix=name_and_model[0]): class_name = cell.__class__.__name__ # 跳过黑名单cell