diff --git a/debug/accuracy_tools/msprobe/find_first/__init__.py b/debug/accuracy_tools/msprobe/find_first/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/debug/accuracy_tools/msprobe/find_first/analyzer.py b/debug/accuracy_tools/msprobe/find_first/analyzer.py new file mode 100644 index 0000000000000000000000000000000000000000..888d97b36882d653c9f4db62369cde5db447c8dc --- /dev/null +++ b/debug/accuracy_tools/msprobe/find_first/analyzer.py @@ -0,0 +1,256 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import time +from collections import defaultdict +import os +from itertools import dropwhile, chain + +from msprobe.core.common.file_utils import check_file_or_directory_path, save_json, make_dir +from msprobe.core.common.log import logger +from msprobe.core.common.const import Const +from msprobe.find_first.utils import (RankPath, FileCache, is_communication_op, is_ignore_op, DiffAnalyseConst, + analyze_diff_in_group) +from msprobe.find_first.graph import DataNode, CommunicationNode + + +class DiffAnalyzer: + def __init__(self, input_path, output_path): + self._input_path = input_path + self._output_path = output_path + self._paths = {} + self._resolve_input_path() + self._diff_nodes = [] # 记录所有异常节点 + self._cache = FileCache() + self._first_comm_nodes = {} # 记录各rank下首个通信节点的node_id + self._after_comm_diffs = {} # 记录各rank下发生在通信节点之后的异常计算节点 + self._rank_comm_nodes_dict = {} # 记录各rank的通信节点 + + def analyze(self): + for analyze_func in [self._pre_analyze, self._analyze, self._post_analyze]: + analyze_func() + if self._diff_nodes: + self._gen_analyze_info() + return + logger.info('Cannot find any diff node, no need to generate analyze file.') + + """ + 这里需要生成stack,但是直接用dict中自带就行,在op_items.NPU_Stack_Info中 + """ + def _resolve_input_path(self): + contents = os.listdir(self._input_path) + for path in contents: + if not path.startswith('rank'): + continue + rank_str = path.strip('rank') + if not rank_str: + rank = 0 + elif not rank_str.isdigit(): + continue + else: + rank = int(rank_str) + dump_path = os.path.join(self._input_path, path, DiffAnalyseConst.DUMP_FILE) + self._paths[rank] = RankPath(rank, dump_path) + + def _pre_analyze(self): + logger.info('Start searching diff node before communication.') + for path in self._paths.values(): + dump_data = self._cache.load_json(path.dump_path) + if not dump_data: + logger.warning(f'Rank {path.rank} has no dump data!') + continue + for op_name, op_data in dump_data.items(): + if is_communication_op(op_name): + 
self._first_comm_nodes[path.rank] = op_name + break + data_node = DataNode(op_name, path.rank, op_data) + if data_node.is_diff: + self._diff_nodes.append(data_node) + break + + def _analyze(self): + logger.info('Start searching diff node during communication.') + self._rank_comm_nodes_dict = {rank: self._analyze_comm_nodes(rank) for rank in self._paths} + self._connect_comm_nodes() + self._pruning() + self._search_first_diff() + + def _post_analyze(self): + logger.info('Start searching diff node after communication.') + for nodes in self._after_comm_diffs.values(): + if nodes: + self._diff_nodes.append(nodes[0]) + + def _connect_comm_nodes(self): + searched_ranks = set() + for rank, nodes in list(self._rank_comm_nodes_dict.items())[:-1]: + searched_ranks.add(rank) + seen_nodes = set() + for cur_node in nodes.values(): + conn_info = cur_node.find_connected_nodes() + if not conn_info.get('ranks'): + conn_info['ranks'] = self._rank_comm_nodes_dict.keys() + if not self._find_connection(conn_info, cur_node, searched_ranks, seen_nodes): + logger.info(f'Cannot find connected communication node for "{cur_node.node_id}".') + + def _find_connection(self, conn_info, cur_node, searched_ranks, seen_nodes): + def connect(search_node): + seen_nodes.add(search_node.node_id) + if search_node.type == DiffAnalyseConst.DST: + cur_node.add_dst(search_node) + elif search_node.type == DiffAnalyseConst.SRC: + search_node.layer = cur_node.layer + search_node.add_dst(cur_node) + else: + cur_node.add_link(search_node) + + found = cur_node.connected + for connected_rank in conn_info['ranks']: + if connected_rank in searched_ranks: + continue + tar_id_prefix = f'{connected_rank}.{conn_info["api"]}' + for search_id, search_node in self._rank_comm_nodes_dict[connected_rank].items(): + if search_id in seen_nodes: + continue + if not (search_id.startswith(tar_id_prefix) and search_node.type == conn_info.get('type')): + continue + search_conn_ranks = search_node.find_connected_nodes().get('ranks') 
+ if ((not search_conn_ranks and search_node.api not in DiffAnalyseConst.DIRECTED_API) or + cur_node.rank in search_conn_ranks): # 有些无向通信算子没有填ProcessGroup,默认连接所有rank + connect(search_node) + found = True + break + return found + + def _analyze_comm_nodes(self, rank): + path = self._paths[rank] + data = self._cache.load_json(path.dump_path) + communication_nodes = {} + if rank not in self._first_comm_nodes: # 此rank没有通信节点 + return communication_nodes + last_node_id = None # 记录上一个通信节点的node_id + compute_ops = [] # 记录两个通信节点之间的计算节点 + sub_layer = 0 # 记录两个通信算子之间异常计算节点的调用序数 + for op_name in dropwhile(lambda k: k != self._first_comm_nodes[rank], data): + node_id = f'{rank}.{op_name}' + op_data = data[op_name] + if is_communication_op(op_name): + comm_node = CommunicationNode(node_id, rank, DataNode(op_name, rank, op_data, sub_layer=sub_layer), + compute_ops=compute_ops) + if last_node_id: + communication_nodes.get(last_node_id).add_next(comm_node) + communication_nodes[node_id] = comm_node + last_node_id = node_id + compute_ops = [] + sub_layer = 0 + elif not is_ignore_op(op_name): + data_node = DataNode(op_name, rank, op_data, sub_layer=sub_layer) + if data_node.is_diff: + compute_ops.append(data_node) + sub_layer += 1 + if compute_ops: + self._after_comm_diffs[rank] = compute_ops + return communication_nodes + + def _pruning(self): + deleted_node_id = [] + for nodes in self._rank_comm_nodes_dict.values(): + for node_id in list(nodes.keys()): + node = nodes[node_id] + if node.is_diff or node.compute_ops: + continue + deleted_node_id.append(node_id) + node.delete() + del nodes[node_id] + logger.debug(f'After pruning, following nodes are removed: [{", ".join(deleted_node_id)}]') + + def _search_first_diff(self): + nodes_queues = [] + for comm_nodes in self._rank_comm_nodes_dict.values(): + nodes_queues.append(sorted(list(comm_nodes.values()), key=lambda x: x.layer)) + seen_nodes = set() + + def get_next_node(node_list): + while node_list: + next_node = node_list.pop(0) + if 
next_node.node_id not in seen_nodes: + return next_node + return None + + def find_all_members(ori_node): + ids = get_relative_ids(ori_node) + id_queue = list(chain(*[get_relative_ids(self._get_node_by_id(n_id)).difference(ids) for n_id in ids])) + while id_queue: + new_id = id_queue.pop(0) + ids.add(new_id) + id_queue.extend(get_relative_ids(self._get_node_by_id(new_id)).difference(ids)) + return ids + + def get_relative_ids(ori_node): + if not ori_node: + return set() + return ({ori_node.node_id} | ori_node.link_nodes.keys() | ori_node.src_nodes.keys() | + ori_node.dst_nodes.keys()) + + while any(nodes_queues): + groups = [] + all_ids_in_groups = set() + for nodes in nodes_queues: + node = get_next_node(nodes) + if not node: + continue + if not groups or node.node_id in all_ids_in_groups: + new_group = find_all_members(node) + groups.append(new_group) + all_ids_in_groups.update(new_group) + for group in groups: + seen_nodes.update(group) + self._diff_nodes.extend(analyze_diff_in_group([self._get_node_by_id(n_id) for n_id in group])) + if self._diff_nodes: + self._diff_nodes = [min(self._diff_nodes, key=lambda x: (x.layer, x.sub_layer))] + return + + def _get_node_by_id(self, node_id): + splits = node_id.split(Const.SEP, 1) + if len(splits) < 2 or not splits[0].isdigit(): + logger.error(f'invalid node_id {node_id}') + raise RuntimeError(f'invalid node_id {node_id}') + rank = int(splits[0]) + return self._rank_comm_nodes_dict.get(rank, {}).get(node_id) + + def _gen_analyze_info(self): + if not os.path.exists(self._output_path): + make_dir(self._output_path) + file_name = f'diff_analyze_{time.time_ns()}.json' + result_file = os.path.join(self._output_path, file_name) + result_content = defaultdict(list) + for node in self._diff_nodes: + result_content[f'rank_{node.rank}'].append(node.gen_node_info(self._paths[node.rank])) + save_json(result_file, result_content, 2) + logger.info(f"The analyze result is saved in: {result_file}") + + +def _diff_analyze_parser(parser): 
+ parser.add_argument("-i", "--input_path", dest="input_path", default="", type=str, + help=" The dump file path, over step level. eg: \"xxx/step_0/\".", + required=True) + parser.add_argument("-o", "--output_path", dest="output_path", default="./output", type=str, + help=" The diff analyze result output file path.", + required=False) + + +def _run_diff_analyze(args): + check_file_or_directory_path(args.input_path, True) + DiffAnalyzer(args.input_path, args.output_path).analyze() \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/find_first/graph.py b/debug/accuracy_tools/msprobe/find_first/graph.py new file mode 100644 index 0000000000000000000000000000000000000000..cb6cb7f69e8f00dd3073a09368b86f0b10d57433 --- /dev/null +++ b/debug/accuracy_tools/msprobe/find_first/graph.py @@ -0,0 +1,167 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from dataclasses import dataclass +from msprobe.core.common.const import Const +from msprobe.core.common.log import logger +from msprobe.core.common.const import CompareConst +from msprobe.find_first.utils import RankPath, DiffAnalyseConst + + +@dataclass +class DataNode: + op_name: str + rank: int + inputs: dict + outputs: dict + op_data: list + is_same: bool + layer: int = 0 # 和communication_node的layer保持一致 + sub_layer: int = 0 # 调用顺序,越小表示越先调用 + + def __init__(self, op_name, rank, op_data, **kwargs): + self.op_name = op_name + self.rank = rank + self.stack = None + self.inputs = {} + self.outputs = {} + self.is_diff = False + self.parse_data(op_data) + self.sub_layer = kwargs.get('sub_layer', 0) + + def find_stack(self): + for item in self.stack: + if len(item) >= 2 and self.op_name in item[0]: + return item[1] + return {} + + def parse_data(self, op_data): + self.is_diff = not op_data.get("is_same", True) + self.op_data = op_data.get("op_items") # 这里拿到的是比对column,是一个list,有若干行 + for cmp_data in self.op_data: + name = cmp_data.get(CompareConst.NPU_NAME) + metrics = {CompareConst.NPU_MAX: cmp_data.get(CompareConst.NPU_MAX), + CompareConst.NPU_MIN: cmp_data.get(CompareConst.NPU_MIN), + CompareConst.NPU_MEAN: cmp_data.get(CompareConst.NPU_MEAN), + CompareConst.NPU_NORM: cmp_data.get(CompareConst.NPU_NORM)} + if cmp_data.get(CompareConst.STACK) != "N/A" and not self.stack: + self.stack = cmp_data.get(CompareConst.STACK) + if "input" in name: + self.inputs[name] = metrics + elif "output" in name: + self.outputs[name] = metrics + + def gen_node_info(self, path: RankPath): + data_info_list = {Const.INPUT: self.inputs, Const.OUTPUT: self.outputs} + return {'op_name': self.op_name, + 'data_info': data_info_list, + 'stack_info': self.stack} + + +class CommunicationNode: + def __init__(self, node_id, rank, data: DataNode, layer=0, **kwargs): + self.node_id = node_id + self.rank = rank + self.data = data + self.is_diff = data.is_diff + self.layer = layer + op_name_split = 
self.data.op_name.split(Const.SEP) + if len(op_name_split) < 4: + logger.error(f'invalid op_name: {self.data.op_name}') + raise RuntimeError(f'invalid op_name: {self.data.op_name}') + self.api = op_name_split[1] + self.call_cnt = op_name_split[2] + self.pre_node = kwargs.get('pre_node') + self.link_nodes = kwargs.get('link_nodes', {}) + self.dst_nodes = kwargs.get('dst_nodes', {}) + self.src_nodes = kwargs.get('src_nodes', {}) + self.next_nodes = kwargs.get('next_nodes', {}) + self.compute_ops = kwargs.get('compute_ops', []) + self.type = self._resolve_type() + self.connected = False + + def add_next(self, node): + self.next_nodes[node.node_id] = node + node.pre_node = self + node.layer = self.layer + 1 + node.data.layer = node.layer + + def add_link(self, node): + self.link_nodes[node.node_id] = node + node.link_nodes[self.node_id] = self + node.layer = self.layer + node.data.layer = node.layer + self.connected = True + node.connected = True + + def add_dst(self, node): + self.dst_nodes[node.node_id] = node + node.src_nodes[self.node_id] = self + node.layer = self.layer + node.data.layer = node.layer + self.connected = True + node.connected = True + + def delete(self): + for node in self.next_nodes.values(): + node.pre_node = None + for node in self.dst_nodes.values(): + if node.src_nodes: + node.src_nodes.pop(self.node_id) + for node in self.src_nodes.values(): + if node.dst_nodes: + node.dst_nodes.pop(self.node_id) + for node in self.link_nodes.values(): + if node.link_nodes: + node.link_nodes.pop(self.node_id) + if self.pre_node: + if self.pre_node.next_nodes: + self.pre_node.next_nodes.pop(self.node_id) + + def find_connected_nodes(self): + """ + 根据 api/类型/入参/调用次数 确定相连接的node的op_name + """ + tar_api = DiffAnalyseConst.P2P_API_MAPPING.get(self.api, self.api) + ranks = set() + # 遍历DST和SRC相关的input,获取对应的rank值 + # 遍历inputs获取所有rank值 + for input_name, v in self.data.inputs.items(): + # 检查key是否包含DST/SRC相关标识 + target_types = [DiffAnalyseConst.DST, 
DiffAnalyseConst.DST_GROUP, + DiffAnalyseConst.SRC, DiffAnalyseConst.SRC_GROUP] + if any(keyword in input_name for keyword in target_types): + # 获取NPU_MAX值并转为整数 + rank_val = v.get(CompareConst.NPU_MAX) + if rank_val is not None: + ranks.add(int(rank_val)) + elif input_name.endswith('.group'): + val = v.get(CompareConst.NPU_MAX) + if val and val.startswith('[') and val.endswith(']'): + val = [int(part) for part in val.strip('[]').split(',')] + ranks.update(val) + + return {'ranks': ranks, 'api': f'Distributed.{tar_api}', + 'type': DiffAnalyseConst.OPPOSITE_DIR.get(self.type, DiffAnalyseConst.LINK)} + + def _resolve_type(self): + # 遍历SRC和DST相关的输入,根据rank值判断节点类型 + for prefix, node_type in [(DiffAnalyseConst.SRC, DiffAnalyseConst.SRC), + (DiffAnalyseConst.DST, DiffAnalyseConst.DST)]: + for k, v in self.data.inputs.items(): + if prefix in k or f"{prefix}_group" in k: + return node_type if v.get(CompareConst.NPU_MAX) == self.rank \ + else DiffAnalyseConst.OPPOSITE_DIR[node_type] + return DiffAnalyseConst.LINK diff --git a/debug/accuracy_tools/msprobe/find_first/utils.py b/debug/accuracy_tools/msprobe/find_first/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..58729c118e0e51ae9c2ff730b2b540d43f18d1e0 --- /dev/null +++ b/debug/accuracy_tools/msprobe/find_first/utils.py @@ -0,0 +1,189 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from collections import OrderedDict +from dataclasses import dataclass +import sys +import time +import psutil + +from msprobe.core.common.file_utils import check_file_or_directory_path, load_json + + +@dataclass +class RankPath: + rank: int + dump_path: str + construct_path: str + stack_path: str + + def __init__(self, rank, dump_path): + self.rank = rank + check_file_or_directory_path(dump_path) + self.dump_path = dump_path + + +class FileCache: + """ + lazy load file + """ + _instance = None + + def __new__(cls, *args, **kwargs): + if not cls._instance: + cls._instance = super().__new__(cls, *args, **kwargs) + return cls._instance + + def __init__(self): + self._max_memory_usage = psutil.virtual_memory().available / 4 # 最大占用当前可用内存空间的1/4 + self._cache = OrderedDict() + self._access_cnt = {} + self._access_time = {} + self._size = {} + + @staticmethod + def _sizeof(obj): + seen = set() + objs = [obj] + size = 0 + while objs: + obj = objs.pop() + obj_id = id(obj) + if obj_id in seen: + continue + seen.add(obj_id) + size += sys.getsizeof(obj) + if isinstance(obj, dict): + objs.extend(obj.keys()) + objs.extend(obj.values()) + elif isinstance(obj, (list, tuple, set, frozenset)): + objs.extend(obj) + return size + + def load_json(self, json_path): + if json_path in self._cache: + self._access_cnt[json_path] += 1 + self._access_time[json_path] = time.monotonic() + self._cache.move_to_end(json_path) + return self._cache[json_path] + self._cleanup() + return self._load(json_path) + + def _load(self, json_path): + data = load_json(json_path) + self._add_to_cache(json_path, data) + return data + + def _add_to_cache(self, key, data): + if key in self._cache: + self._cache.move_to_end(key) + else: + self._cache[key] = data + self._access_cnt[key] = 0 + self._access_time[key] = time.monotonic() + self._size[key] = self._sizeof(data) + + def _calc_cache_size(self): + return sys.getsizeof(self._cache) + sum(self._size.values()) + + def _cleanup(self): + while 
self._calc_cache_size() > self._max_memory_usage and self._cache: + least_frequent_key = min(self._access_cnt.keys(), key=lambda k: self._access_cnt[k]) + least_recent_key = min(self._access_time.keys(), key=lambda k: self._access_time[k]) + largest_key = max(self._cache.keys(), key=lambda k: self._size[k]) + key_to_rm = min([least_frequent_key, least_recent_key, largest_key], + key=lambda k: (self._access_cnt[k], self._access_time[k], -self._size[k])) + del self._cache[key_to_rm] + del self._access_cnt[key_to_rm] + del self._access_time[key_to_rm] + del self._size[key_to_rm] + + +def is_communication_op(op_name): + # 定义通信算子的关键字,覆盖各种通信操作,如all_reduce, send, broadcast等 + # 从wrap文件中读取,先硬编码在文件中 + return (op_name.startswith('Distributed.') and + any(keyword in op_name for keyword in DiffAnalyseConst.COMMUNICATION_KEYWORDS)) + + +def is_ignore_op(op_name): + ignore_keywords = [ + 'Torch.empty', + 'Torch.fill' + ] + return any(keyword in op_name for keyword in ignore_keywords) + + +class DiffAnalyseConst: + COMMUNICATION_KEYWORDS = { + 'send', # send 算子 + 'recv', # recv 算子 + 'broadcast', # broadcast 算子 + 'all_reduce', # all_reduce 算子 + 'reduce', # reduce 算子 + 'all_gather', # all_gather 算子 + 'gather', # gather 算子 + 'isend', # isend 算子 + 'irecv', # irecv 算子 + 'scatter', # scatter 算子 + 'reduce_scatter', # reduce_scatter 算子 + '_reduce_scatter_base', # _reduce_scatter_base 算子 + '_all_gather_base', # _all_gather_base 算子 + 'all_to_all_single', # all_to_all_single 算子 + 'all_to_all', # all_to_all 算子 + 'all_gather_into_tensor', # all_gather_into_tensor 算子 + 'reduce_scatter_tensor', # reduce_scatter_tensor 算子 + 'send_object_list', # send_object_list 算子 + 'recv_object_list' # recv_object_list 算子 + } + P2P_API_MAPPING = {'send': 'recv', 'recv': 'send', 'isend': 'irecv', 'irecv': 'isend', + 'send_object_list': 'recv_object_list', 'recv_object_list': 'send_object_list'} + SRC = 'src' + DST = 'dst' + SRC_GROUP = 'src_group' + DST_GROUP = 'dst_group' + LINK = 'link' + DIRECTED_API = 
{'send': DST, 'recv': SRC, 'isend': DST, 'irecv': SRC, 'broadcast': SRC, 'scatter': SRC, + 'gather': DST, 'send_object_list': DST, 'recv_object_list': SRC} + OPPOSITE_DIR = {SRC: DST, DST: SRC} + DUMP_FILE = "dump.json" + CONSTRUCT_FILE = "construct.json" + STACK_FILE = "stack.json" + + +def analyze_diff_in_group(nodes_group): + diff_nodes = [] + + def get_compute_ops_from_comm_nodes(comm_nodes): + for comm_node in comm_nodes: + for op_node in comm_node.compute_ops: + op_node.layer = comm_node.layer + diff_nodes.append(op_node) + + def get_comm_ops(comm_nodes): + for node in comm_nodes: + node.data.layer = node.layer + diff_nodes.append(node.data) + + # 先看src或link中input是否有异常 + src_list = list(filter(lambda node: node.type in [DiffAnalyseConst.SRC, DiffAnalyseConst.LINK], nodes_group)) + input_diff_nodes = list(filter(lambda node: node.is_diff, src_list)) + # 如果有异常回溯计算节点找到异常来源 + # 使用cpu模拟节点进行计算,查看结果是否有问题。需要对所有计算节点录入/映射,暂不实现。 + get_compute_ops_from_comm_nodes(input_diff_nodes) + # 筛选入参没问题但出参有问题的通信节点 + output_diff_nodes = list(filter(lambda node: node.data.is_diff, nodes_group)) + get_comm_ops(output_diff_nodes) + return diff_nodes \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/mindspore/__init__.py b/debug/accuracy_tools/msprobe/mindspore/__init__.py index 5005d6921e6f56d7c869932c7f8a0ccdb019cb67..c36ea84caace7de247ba97f2c8b504627786f9de 100644 --- a/debug/accuracy_tools/msprobe/mindspore/__init__.py +++ b/debug/accuracy_tools/msprobe/mindspore/__init__.py @@ -25,4 +25,4 @@ except ImportError: from msprobe.mindspore.debugger.precision_debugger import PrecisionDebugger from msprobe.mindspore.common.utils import seed_all, MsprobeStep, MsprobeInitStep from msprobe.mindspore.monitor.module_hook import TrainerMon -from msprobe.mindspore.dump.graph_tensor_dump import save, save_grad \ No newline at end of file +from msprobe.mindspore.dump.graph_tensor_dump import save, save_grad, step \ No newline at end of file diff --git 
a/debug/accuracy_tools/msprobe/mindspore/compare/common_dir_compare.py b/debug/accuracy_tools/msprobe/mindspore/compare/common_dir_compare.py index cb561b1981eb7235c6322456758b456584fc251e..5ae89a605c97dac1d53ca8c15baa4f4fcf2f8226 100644 --- a/debug/accuracy_tools/msprobe/mindspore/compare/common_dir_compare.py +++ b/debug/accuracy_tools/msprobe/mindspore/compare/common_dir_compare.py @@ -134,21 +134,31 @@ def find_npy_files(directory): for root, _, files in os.walk(directory): for file in files: if file.endswith(".npy"): - # 分割文件名并去掉最后两个元素 - file_name = file.split('_') - if len(file_name) < 2: - continue - key = '_'.join(file_name[:-2]) - # 文件的完整路径 - value = os.path.join(root, file) - # 添加到字典中 - if not npy_files_dict.get(key): - npy_files_dict[key] = [] - npy_files_dict[key].append(value) + file_name = file.split('.npy')[0] + key = None + if '_' in file_name: + # 分割文件名并去掉最后两个元素 + file_name = file.split('_') + if len(file_name) < 2: + continue + key = '_'.join(file_name[:-2]) + elif '.' in file_name: + file_ele = file_name.split('.') + if len(file_ele) < 2: + continue + key = '_'.join([file_ele[0], file_ele[1]]) + # 文件的完整路径 + if key: + value = os.path.join(root, file) + # 添加到字典中 + if not npy_files_dict.get(key): + npy_files_dict[key] = [] + npy_files_dict[key].append(value) return npy_files_dict def generate_map_dict(npu_file_dict, bench_file_dict, name_map_dict=None): + result_dict = {} for k, npu_file_list in npu_file_dict.items(): bench_file_list = bench_file_dict.get(k) if not bench_file_list and k in name_map_dict: @@ -156,7 +166,6 @@ def generate_map_dict(npu_file_dict, bench_file_dict, name_map_dict=None): bench_length = len(bench_file_list) if not (bench_file_list and bench_length): continue - result_dict = {} for i, npu_file in enumerate(npu_file_list): if i >= bench_length: break diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/graph_tensor_dump.py b/debug/accuracy_tools/msprobe/mindspore/dump/graph_tensor_dump.py index 
7b3f249e7e7065d52046aa6991a9d8553bb230d6..9d6eedb5be669f60037dcba72bf7a2c02a56112e 100644 --- a/debug/accuracy_tools/msprobe/mindspore/dump/graph_tensor_dump.py +++ b/debug/accuracy_tools/msprobe/mindspore/dump/graph_tensor_dump.py @@ -16,7 +16,8 @@ import os from collections import OrderedDict import mindspore as ms - +from mindspore import hal, ops, Tensor +from mindspore.ops.primitive import _run_op def _iterate_items(data): if isinstance(data, (dict, OrderedDict)): @@ -121,3 +122,9 @@ def save_grad(save_dir, name, data): dump_dir = generate_dump_dir(save_dir) suffix_name = name + '_grad' return _SaveGradCell(dump_dir, suffix_name)(data) + +def step(): + hal.synchronize() + temp_tensor = Tensor([1], dtype=ms.float32) + step_flag = "" + _run_op(ops.TensorDump(), "TensorDump", (step_flag, temp_tensor)) \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/msprobe.py b/debug/accuracy_tools/msprobe/msprobe.py index 221fb1da81a8562c6852fadb6927496779917fba..14bb3f90a307f06a52da8675b203d0b3d57bb559 100644 --- a/debug/accuracy_tools/msprobe/msprobe.py +++ b/debug/accuracy_tools/msprobe/msprobe.py @@ -55,6 +55,7 @@ def main(): merge_result_parser = subparsers.add_parser('merge_result') config_checking_parser = subparsers.add_parser('config_check') nan_analyze_parser = subparsers.add_parser('nan_analyze') + diff_analyze_parser = subparsers.add_parser('diff_analyze') _config_checking_parser(config_checking_parser) _compare_parser(compare_cmd_parser) _merge_result_parser(merge_result_parser) @@ -77,6 +78,7 @@ def main(): from msprobe.pytorch.api_accuracy_checker.generate_op_script.op_generator import _op_generator_parser, \ _run_operator_generate_commond from msprobe.nan_analyze.analyzer import _nan_analyze_parser, _run_nan_analyze + from msprobe.find_first.analyzer import _diff_analyze_parser, _run_diff_analyze _run_ut_parser(run_ut_cmd_parser) _run_ut_parser(multi_run_ut_cmd_parser) @@ -87,6 +89,8 @@ def main(): 
_pt_graph_service_parser(graph_service_cmd_parser) _op_generator_parser(op_generate_cmd_parser) _nan_analyze_parser(nan_analyze_parser) + _diff_analyze_parser(diff_analyze_parser) + elif framework_args.framework == Const.MS_FRAMEWORK: from msprobe.mindspore.api_accuracy_checker.cmd_parser import add_api_accuracy_checker_argument from msprobe.visualization.graph_service import _ms_graph_service_parser, _ms_graph_service_command @@ -133,6 +137,9 @@ def main(): _run_config_checking_command(args) elif sys.argv[3] == "nan_analyze": _run_nan_analyze(args) + elif sys.argv[3] == "diff_analyze": + _run_diff_analyze(args) + else: if not is_module_available(Const.MS_FRAMEWORK): logger.error("MindSpore does not exist, please install MindSpore library")