diff --git a/debug/accuracy_tools/msprobe/core/common/const.py b/debug/accuracy_tools/msprobe/core/common/const.py index 22e02d0d98d525d8b339745054d05401c58a4fca..6775c028af410527296378eca67462720540ad5a 100644 --- a/debug/accuracy_tools/msprobe/core/common/const.py +++ b/debug/accuracy_tools/msprobe/core/common/const.py @@ -192,12 +192,14 @@ class Const: TORCH_FLOAT32 = "torch.float32" TORCH_BFLOAT16 = "torch.bfloat16" + TYPE = 'type' DTYPE = 'dtype' SHAPE = 'shape' MAX = 'Max' MIN = 'Min' MEAN = 'Mean' NORM = 'Norm' + DATA_NAME = 'data_name' CODE_STACK = 'Code Stack' OP_NAME = 'Op Name' @@ -210,6 +212,8 @@ class Const: SCOPE_SEPARATOR = "/" REPLACEMENT_CHARACTER = "_" + FORWARD_PATTERN = SEP + FORWARD + SEP + BACKWARD_PATTERN = SEP + BACKWARD + SEP class CompareConst: diff --git a/debug/accuracy_tools/msprobe/mindspore/common/const.py b/debug/accuracy_tools/msprobe/mindspore/common/const.py index c99095e99a8650c0fd678ea7b9763c244cf97f5f..2193e0ac10f0eff1ca9f046ffbf250ef7b2639a5 100644 --- a/debug/accuracy_tools/msprobe/mindspore/common/const.py +++ b/debug/accuracy_tools/msprobe/mindspore/common/const.py @@ -59,6 +59,7 @@ class Const: DROPOUT_API_NAME_PREFIX = "dropout" GRAPH_DATA_MODE_LIST = [CoreConst.ALL, CoreConst.INPUT, CoreConst.OUTPUT] + GRAPH_CELL_DUMP_DATA_MODE_LIST = [CoreConst.ALL, CoreConst.FORWARD, CoreConst.BACKWARD] HOOK_MS_PREFIX_DICT = { OPS_DATA_PREFIX: OPS_PREFIX, diff --git a/debug/accuracy_tools/msprobe/mindspore/debugger/precision_debugger.py b/debug/accuracy_tools/msprobe/mindspore/debugger/precision_debugger.py index d928e0a3a62504f07ffb8807a43d405773921861..4efd9ffefd8e9a4876025781c39ced99507b8c4b 100644 --- a/debug/accuracy_tools/msprobe/mindspore/debugger/precision_debugger.py +++ b/debug/accuracy_tools/msprobe/mindspore/debugger/precision_debugger.py @@ -29,7 +29,7 @@ from msprobe.mindspore.ms_config import parse_json_config from msprobe.mindspore.runtime import Runtime from msprobe.mindspore.service import Service from msprobe.mindspore.task_handler_factory import TaskHandlerFactory - +from msprobe.mindspore.dump.graph_mode_cell_dump import GraphModeCellDump class PrecisionDebugger: _instance = None @@ -93,7 +93,7 @@ class PrecisionDebugger: instance.service.start(model) else: if not instance.first_start: - handler = TaskHandlerFactory.create(instance.config) + handler = TaskHandlerFactory.create(instance.config, model) handler.handle() instance.first_start = True @@ -129,6 +129,9 @@ class PrecisionDebugger: raise Exception(MsgConst.NOT_CREATED_INSTANCE) if instance.task in PrecisionDebugger.task_not_need_service: return + if instance.config.execution_mode == MsConst.GRAPH_GE_MODE or instance.config.execution_mode == MsConst.GRAPH_KBYK_MODE: + GraphModeCellDump.step() + return if instance.service: instance.service.step() HOOKCell.cell_count = defaultdict(int) diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/cell_dump_process.py b/debug/accuracy_tools/msprobe/mindspore/dump/cell_dump_process.py new file mode 100644 index 0000000000000000000000000000000000000000..1b7617555588b8bcb4e3147a1c67c9c39f448d75 --- /dev/null +++ b/debug/accuracy_tools/msprobe/mindspore/dump/cell_dump_process.py @@ -0,0 +1,357 @@ +from mindspore import nn, ops +from mindspore.communication import init, get_rank +import mindspore as ms +import inspect +import os +import time +import re +import json +import atexit +import numpy as np +from multiprocessing import Pool +from msprobe.mindspore.common.log import logger +from msprobe.core.common.const import Const as CoreConst +from msprobe.core.common.file_utils import load_npy, save_json +from msprobe.core.common.const import FileCheckConst + + +CONSTRUCT_FILE_NAME = "construct.json" +DEFAULT_RANK_DIR = "rank0" +KEY_LAYERS = "layers" +pattern = re.compile(r'(\d+)_(\w+)_(\d+)') +np_ms_dtype_dict = { + "int8": ms.int8, + "byte": ms.byte, + "int16": ms.int16, + "short": ms.short, + "int32": ms.int32, + "intc": ms.intc, + "int64": ms.int64, + "intp": ms.intp, + "uint8": ms.uint8, + "ubyte": ms.ubyte, + "uint16": ms.uint16, + "ushort": ms.ushort, + "uint32": ms.uint32, + "uintc": ms.uintc, + "uint64": ms.uint64, + "uintp": ms.uintp, + "float16": ms.float16, + "half": ms.half, + "float32": ms.float32, + "single": ms.single, + "float64": ms.float64, + "double": ms.double, + "bfloat16": ms.bfloat16, + "complex64": ms.complex64, + "complex128": ms.complex128 +} + +def generate_file_path(dump_path, cell_prefix, suffix, io_type, index): + step_path = os.path.join(dump_path, "{step}") + rank_path = os.path.join(step_path, "{rank}") + data_path = os.path.join(rank_path, CoreConst.DUMP_TENSOR_DATA) + file_name = CoreConst.SEP.join([cell_prefix, suffix, io_type, str(index)]) + return os.path.join(data_path, file_name) + + +def partial_func(func, dump_path, cell_prefix, index, io_type): + def newfunc(*args, **kwargs): + return func(dump_path, cell_prefix, index, io_type, *args, **kwargs) + return newfunc + + +def clip_gradient(dump_path, cell_prefix, index, io_type, dx): + if io_type == CoreConst.OUTPUT: + ops.TensorDump()(generate_file_path(dump_path, cell_prefix, CoreConst.BACKWARD, io_type, index), dx) + if io_type == CoreConst.INPUT: + ops.TensorDump("in")(generate_file_path(dump_path, cell_prefix, CoreConst.BACKWARD, io_type, index), dx) + return dx + + +def cell_construct_wrapper(func, self): + def new_construct(self, *args, **kwargs): + new_args = [] + + # The inputs of the cell. + for index, item in enumerate(args): + if self.data_mode == "backward" or self.data_mode == "all": + if ops.is_tensor(item): + item = self.input_clips[index](item) + if self.data_mode == "forward" or self.data_mode == "all": + if ops.is_tensor(item): + ops.TensorDump("in")(generate_file_path(self.dump_path, self.cell_prefix, CoreConst.FORWARD, CoreConst.INPUT, index), item) + new_args.append(item) + + out = func(*new_args, **kwargs) + + # The outputs of the cell. + if isinstance(out, tuple): + for index, item in enumerate(out): + if self.data_mode == "backward" or self.data_mode == "all": + if ops.is_tensor(item): + item = self.output_clips[index](item) + if self.data_mode == "forward" or self.data_mode == "all": + if ops.is_tensor(item): + ops.TensorDump()(generate_file_path(self.dump_path, self.cell_prefix, CoreConst.FORWARD, CoreConst.OUTPUT, index), item) + else: + if self.data_mode == "backward" or self.data_mode == "all": + out = self.output_clips[0](out) + if self.data_mode == "forward" or self.data_mode == "all": + ops.TensorDump()(generate_file_path(self.dump_path, self.cell_prefix, CoreConst.FORWARD, CoreConst.OUTPUT, 0), out) + + return out + + return new_construct.__get__(self, type(self)) + + +def rename_filename(path): + # Get all the file names in the folder and sort them by the number after "_". + filenames = os.listdir(path) + pattern = re.compile(rf'{CoreConst.REPLACEMENT_CHARACTER}(\d+){CoreConst.NUMPY_SUFFIX}$') + filenames.sort(key=lambda x: int(pattern.findall(x)[0])) + + filename_dict = {} + for filename in filenames: + match = re.search(rf"{CoreConst.CELL}{CoreConst.SEP}(.*?){CoreConst.SEP}({CoreConst.INPUT}|{CoreConst.OUTPUT}){CoreConst.SEP}",filename) + mid_field = match.group(1) + + if mid_field in filename_dict: + filename_dict[mid_field].append(filename) + else: + filename_dict[mid_field] = [filename] + + # Change the file name and add the sequence number of the cell that is repeatedly called. + for mid_field, filename_list in filename_dict.items(): + last_second_index = filename_list[0].rfind(CoreConst.REPLACEMENT_CHARACTER, 0, filename_list[0].rfind(CoreConst.REPLACEMENT_CHARACTER)) + first_file_sub = filename_list[0][:last_second_index] + index_list = [] + for index, filename in enumerate(filename_list): + if first_file_sub in filename: + index_list.append(index) + index_list.append(len(filename_list)) + + for i in range(len(index_list) - 1): + start_index = index_list[i] + end_index = index_list[i + 1] + for j in range(start_index, end_index): + if CoreConst.FORWARD_PATTERN in filename_list[j]: + newFileName = filename_list[j].replace(CoreConst.FORWARD_PATTERN, CoreConst.FORWARD_PATTERN + str(i) + CoreConst.SEP) + if CoreConst.BACKWARD_PATTERN in filename_list[j]: + newFileName = filename_list[j].replace(CoreConst.BACKWARD_PATTERN, CoreConst.BACKWARD_PATTERN + str(i) + CoreConst.SEP) + os.rename(os.path.join(path,filename_list[j]), os.path.join(path,newFileName)) + logger.info(f"==========The rename_filename phase is Finished!==========") + + +# Extract the field between the first "." and the third to last ".", i.e. {cell_name} +def get_cell_name(str): + parts = str.split(CoreConst.SEP) + if len(parts) < 4: + return None + start_index = 1 + end_index = len(parts) - 3 + return CoreConst.SEP.join(parts[start_index:end_index]) + + +# Extract the field between the last "." and the second to last ".", i.e. {data_made} +def get_data_mode(str): + last_dot_index = str.rfind(CoreConst.SEP) + second_last_dot_index = str.rfind(CoreConst.SEP, 0, last_dot_index) + data_mode = str[second_last_dot_index + 1:last_dot_index] + return data_mode + + +# Determine whether there is a father son relationship between the two. +def check_relation(cell_name, parent_cell_name): + layers_pattern = rf"{CoreConst.SEP}{KEY_LAYERS}{CoreConst.SEP}\d+$" + last_dot_index = cell_name.rfind(CoreConst.SEP) + if last_dot_index != -1: + sub_cell_name = cell_name[:last_dot_index] + if sub_cell_name == parent_cell_name: + return True + elif re.search(layers_pattern, cell_name): + sub_cell_name = re.sub(layers_pattern, '', cell_name) + if sub_cell_name == parent_cell_name: + return True + return False + + +construct={} +def get_construct(cell_list): + parent_cell_Stack = [] + for cell in cell_list: + cell_name = get_cell_name(cell) + cell_data_mode = get_data_mode(cell) + found_flag = False + for parent_cell in reversed(parent_cell_Stack): + parent_cell_name = get_cell_name(parent_cell) + parent_data_mode = get_data_mode(parent_cell) + has_relation = check_relation(cell_name, parent_cell_name) + if has_relation and parent_data_mode == cell_data_mode: + construct.update({cell: parent_cell}) + found_flag = True + break + if not found_flag: + construct.update({cell: None}) + parent_cell_Stack.append(cell) + + +def generate_construct(path): + global construct + # Get all the file names in the folder and sort them by the number after "_". + filenames = os.listdir(path) + pattern = re.compile(rf'{CoreConst.REPLACEMENT_CHARACTER}(\d+){CoreConst.NUMPY_SUFFIX}$') + filenames.sort(key=lambda x: int(pattern.findall(x)[0])) + + cell_list = [] + for filename in filenames: + match = re.search(rf"({CoreConst.CELL}{CoreConst.SEP}).*?(?={CoreConst.SEP}(?:{CoreConst.INPUT}{CoreConst.SEP}|{CoreConst.OUTPUT}{CoreConst.SEP}))", filename) + mid_field = match.group(0) + if mid_field not in cell_list: + cell_list.append(mid_field) + + get_construct(cell_list) + + # Generate JSON file + rank_dir = os.path.dirname(path) + json_path = os.path.join(rank_dir, CONSTRUCT_FILE_NAME) + with open(json_path, "w") as f: + json.dump(construct, f, indent=1) + + # Clear 'construct' and continue processing data for the next path + construct={} + logger.info(f"==========The genarete_construct phase is Finished!==========") + + +def process_file(file_path): + try: + # 读取.npy文件内容 + npy_content = load_npy(file_path) + logger.info(f"Loaded {file_path}: shape is {npy_content.shape}, dtype is {npy_content.dtype}") + + # 文件名举例:Cell.network._backbone.loss.CrossEntropyLoss.forward.0.input.0_float32_165.npy + parts = os.path.basename(file_path).split('.') + # 使用正则表达式搜索文件名 + data_dtype = "" + match = pattern.search(parts[-2]) + if match: + # 提取数据类型 + data_dtype = match.group(2) + + # op_name是Cell.network._backbone.loss.CrossEntropyLoss.forward.0 + op_name = '.'.join(parts[:-3]) + tensor_json = { + CoreConst.TYPE: 'mindspore.Tensor', + CoreConst.DTYPE: str(np_ms_dtype_dict.get(data_dtype)), + CoreConst.SHAPE: list(npy_content.shape), + CoreConst.MAX: npy_content.max().item(), + CoreConst.MIN: npy_content.min().item(), + CoreConst.MEAN: npy_content.mean().item(), + CoreConst.NORM: np.linalg.norm(npy_content).item(), + CoreConst.DATA_NAME: os.path.basename(file_path) + } + + # 根据文件名的最后一个部分(输入或输出)确定是添加到input_args还是output + if parts[-3] == CoreConst.INPUT: + return op_name, CoreConst.INPUT_ARGS, tensor_json + elif parts[-3] == CoreConst.OUTPUT: + return op_name, CoreConst.OUTPUT, tensor_json + else: + return None, None, None + + except Exception as e: + logger.error(f"Error reading {file_path}: {e}") + return None, None, None + + +def generate_dump_info(path): + if not os.path.exists(path): + logger.error("The provided path does not exist.") + return + + dump_data = {"task": "tensor", "level": "L0", "dump_data_dir": path, "data": {}} + + with Pool(processes=10) as pool: + file_paths = [(os.path.join(root, file),) for root, dirs, files in os.walk(path) for file in files if + file.endswith(FileCheckConst.NUMPY_SUFFIX)] + file_paths.sort() + results = pool.starmap(process_file, file_paths) + + # 收集结果 + for op_name, key, tensor_json in results: + if op_name: + if op_name not in dump_data.get(CoreConst.DATA): + dump_data[CoreConst.DATA][op_name] = {CoreConst.INPUT_ARGS: [], + CoreConst.INPUT_KWARGS: {}, + CoreConst.OUTPUT: []} + dump_data[CoreConst.DATA][op_name][key].append(tensor_json) + + # 将数据写入dump.json + save_json(os.path.join(os.path.dirname(path), 'dump.json'), dump_data, indent=1) + + logger.info(f"Dump data saved to {os.path.join(os.path.dirname(path), 'dump.json')}") + + +def generate_stack_info(path): + if not os.path.exists(path): + logger.error("The provided path does not exist.") + return + + stack_data = {} + file_paths = [os.path.join(root, file) for root, _, files in os.walk(path) for file in files if + file.endswith(FileCheckConst.NUMPY_SUFFIX)] + file_paths.sort() + for file_path in file_paths: + # 文件名举例:Cell.network._backbone.loss.CrossEntropyLoss.forward.0.input.0_float32_165.npy + parts = os.path.basename(file_path).split('.') + # op_name是Cell.network._backbone.loss.CrossEntropyLoss.forward.0 + op_name = '.'.join(parts[:-3]) # 获取所需的部分路径 + stack_data.update({op_name: []}) + + # 将数据写入dump.json + save_json(os.path.join(os.path.dirname(path), 'stack.json'), stack_data, indent=1) + + logger.info(f"Stack data saved to {os.path.join(os.path.dirname(path), 'stack.json')}") + + +def process(dump_path): + time.sleep(10) + logger.info(f"==========Start processing data that has already been stored on the disk!==========") + rank_id = os.environ.get('RANK_ID') + rank_dir = DEFAULT_RANK_DIR + if rank_id != None: + rank_dir = CoreConst.RANK + str(rank_id) + + step_dir_list = os.listdir(dump_path) + for step_dir in step_dir_list: + step_path = os.path.join(dump_path, step_dir) + rank_path = os.path.join(step_path, rank_dir) + npy_path = os.path.join(rank_path, CoreConst.DUMP_TENSOR_DATA) + rename_filename(npy_path) + generate_construct(npy_path) + generate_dump_info(npy_path) + generate_stack_info(npy_path) + + +def start(net=None, dump_path="./", data_mode=CoreConst.ALL): + if net == None: + return + for name, cell in net.cells_and_names(): + if name == "": + continue + else: + cell.cell_prefix = CoreConst.SEP.join([CoreConst.CELL, name, cell.__class__.__name__]) + + cell.construct = cell_construct_wrapper(cell.construct, cell) + logger.info(f"Cell {name}: construct function is wrapped!") + cell.dump_path = dump_path + cell.data_mode = data_mode + cell.input_clips = [] + cell.output_clips = [] + # It is assumed that each cell has a maximum of 50 outputs and 50 inputs. + for i in range(50): + cell.input_clips.append(ops.InsertGradientOf(partial_func(clip_gradient, cell.dump_path, cell.cell_prefix, i, CoreConst.INPUT))) + cell.output_clips.append(ops.InsertGradientOf(partial_func(clip_gradient, cell.dump_path, cell.cell_prefix, i, CoreConst.OUTPUT))) + + logger.info(f"==========The cell_dump_process_start phase is Finished!==========") + atexit.register(process, dump_path=dump_path) diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/dump_tool_factory.py b/debug/accuracy_tools/msprobe/mindspore/dump/dump_tool_factory.py index 0ca63b4a84aee00127bca37b7da36888e905a5aa..ef126ea38130060224f7eb27b95b8c3c28914492 100644 --- a/debug/accuracy_tools/msprobe/mindspore/dump/dump_tool_factory.py +++ b/debug/accuracy_tools/msprobe/mindspore/dump/dump_tool_factory.py @@ -17,13 +17,14 @@ from msprobe.mindspore.common.const import Const from msprobe.mindspore.debugger.debugger_config import DebuggerConfig from msprobe.mindspore.dump.kernel_graph_dump import KernelGraphDump from msprobe.mindspore.dump.kernel_kbyk_dump import KernelKbykDump +from msprobe.mindspore.dump.graph_mode_cell_dump import GraphModeCellDump class DumpToolFactory: tools = { Const.CELL: { - Const.GRAPH_KBYK_MODE: None, - Const.GRAPH_GE_MODE: None, + Const.GRAPH_KBYK_MODE: GraphModeCellDump, + Const.GRAPH_GE_MODE: GraphModeCellDump, Const.PYNATIVE_MODE: None }, Const.API: { @@ -39,9 +40,13 @@ class DumpToolFactory: } @staticmethod - def create(config: DebuggerConfig): - if len(config.data_mode) != 1 or config.data_mode[0] not in Const.GRAPH_DATA_MODE_LIST: - raise Exception("data_mode must be one of all, input, output.") + def create(config: DebuggerConfig, model): + if config.level == Const.CELL: + if len(config.data_mode) != 1 or config.data_mode[0] not in Const.GRAPH_CELL_DUMP_DATA_MODE_LIST: + raise Exception("data_mode must be one of all, forward, backward.") + else: + if len(config.data_mode) != 1 or config.data_mode[0] not in Const.GRAPH_DATA_MODE_LIST: + raise Exception("data_mode must be one of all, input, output.") tool = DumpToolFactory.tools.get(config.level) if not tool: raise Exception("Valid level is needed.") @@ -49,4 +54,7 @@ class DumpToolFactory: if not tool: raise Exception(f"Data dump is not supported in {config.execution_mode} mode " f"when dump level is {config.level}.") - return tool(config) + if tool == GraphModeCellDump: + return tool(config, model) + else: + return tool(config) diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/graph_mode_cell_dump.py b/debug/accuracy_tools/msprobe/mindspore/dump/graph_mode_cell_dump.py new file mode 100644 index 0000000000000000000000000000000000000000..872d3ff0f968267d82ce9fa042cd9d9aca4f7f6b --- /dev/null +++ b/debug/accuracy_tools/msprobe/mindspore/dump/graph_mode_cell_dump.py @@ -0,0 +1,74 @@ +# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from msprobe.mindspore.common.log import logger +from msprobe.mindspore.debugger.debugger_config import DebuggerConfig +import mindspore as ms +from mindspore._c_expression import _tensordump_set_step +from mindspore.ops.primitive import _run_op +from mindspore import hal, ops +import os +import msprobe.mindspore.dump.cell_dump_process as cellDumper +from msprobe.mindspore.common.const import Const + + +class GraphModeCellDump: + def __init__(self, config: DebuggerConfig, model): + self.net = model + self.white_list = [] + self.black_list = [] + self.dump_path = config.dump_path if config.dump_path else "./" + self.rank = config.rank + self.step = config.step + self.scope = config.scope + self.list = config.list + self.data_mode = config.data_mode + self.file_format = config.file_format + self.check_config() + self.set_step() + + + def check_config(self): + if self.rank != []: + raise Exception("In graph mode, cell dump does not currently support specifying rank.") + if self.scope != []: + raise Exception("In graph mode, cell dump does not currently support specifying scope.") + if self.list != []: + raise Exception("In graph mode, cell dump does not currently support specifying list.") + if len(self.data_mode) != 1 or self.data_mode[0] not in Const.GRAPH_CELL_DUMP_DATA_MODE_LIST: + raise Exception("In graph mode and cell dump, data_mode must be one of all, forword, backword.") + if self.file_format != "npy": + raise Exception("In graph mode and cell dump, file_format must be npy") + if not self.net: + logger.warning("The model is empty and Cell dump is not enabled.") + return True + + + def set_step(self): + _tensordump_set_step(self.step) + + + def handle(self): + os.environ['MS_JIT_MODULES'] = 'msprobe' + cellDumper.start(net=self.net, dump_path=self.dump_path, data_mode=self.data_mode[0]) + + + @staticmethod + def step(): + hal.synchronize() + temp_tensor = ms.Tensor([1], dtype=ms.float32) + step_flag = "" + _run_op(ops.TensorDump(), "TensorDump", (step_flag, temp_tensor)) + ops.tensordump(step_flag, temp_tensor) \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/mindspore/task_handler_factory.py b/debug/accuracy_tools/msprobe/mindspore/task_handler_factory.py index a9cb5e6dd4037dcdeffe3c4d9584ad93c42022d6..5cfbbaeb4a46a197ca6c5eb163db4a27aa1e5c35 100644 --- a/debug/accuracy_tools/msprobe/mindspore/task_handler_factory.py +++ b/debug/accuracy_tools/msprobe/mindspore/task_handler_factory.py @@ -29,11 +29,14 @@ class TaskHandlerFactory: } @staticmethod - def create(config: DebuggerConfig): + def create(config: DebuggerConfig, model): task = TaskHandlerFactory.tasks.get(config.task) if not task: raise Exception("Valid task is needed.") - handler = task.create(config) + if task == DumpToolFactory: + handler = task.create(config, model) + else: + handler = task.create(config) if not handler: raise Exception("Can not find task handler") return handler