From a599e5e1f1e43dcb4b5394d3fc4e39991737b50f Mon Sep 17 00:00:00 2001 From: TAJh Date: Wed, 18 Jun 2025 09:15:48 +0800 Subject: [PATCH 01/18] version 1.0 --- .../msprobe/find_first/__init__.py | 0 .../msprobe/find_first/analyzer.py | 189 ++++++++++++++++++ .../msprobe/find_first/graph.py | 177 ++++++++++++++++ .../msprobe/find_first/utils.py | 24 +++ .../mindspore/monitor/optimizer_collect.py | 17 -- 5 files changed, 390 insertions(+), 17 deletions(-) create mode 100644 debug/accuracy_tools/msprobe/find_first/__init__.py create mode 100644 debug/accuracy_tools/msprobe/find_first/analyzer.py create mode 100644 debug/accuracy_tools/msprobe/find_first/graph.py create mode 100644 debug/accuracy_tools/msprobe/find_first/utils.py diff --git a/debug/accuracy_tools/msprobe/find_first/__init__.py b/debug/accuracy_tools/msprobe/find_first/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/debug/accuracy_tools/msprobe/find_first/analyzer.py b/debug/accuracy_tools/msprobe/find_first/analyzer.py new file mode 100644 index 000000000..a4c1ab691 --- /dev/null +++ b/debug/accuracy_tools/msprobe/find_first/analyzer.py @@ -0,0 +1,189 @@ +import time +from collections import defaultdict +import os +from itertools import dropwhile, chain + +from msprobe.core.common.file_utils import check_file_or_directory_path, save_json, make_dir +from msprobe.core.common.log import logger +from msprobe.core.common.const import Const +from msprobe.nan_analyze.utils import (RankPath, FileCache, is_communication_op, is_ignore_op, NanAnalyseConst, + analyze_anomaly_in_group) +from msprobe.find_first.graph import DataNode, CommunicationNode + + +class NanAnalyzer: + def __init__(self, input_path, output_path): + self._input_path = input_path + self._output_path = output_path + self._paths = {} + self._resolve_input_path() + self._diff_nodes = [] # 记录所有异常节点 + self._cache = FileCache() + self._first_comm_nodes = {} # 记录各rank下首个通信节点的node_id + self._after_comm_anomalies = {} # 记录各rank下发生在通信节点之后的异常计算节点 + self._rank_comm_nodes_dict = {} # 记录各rank的通信节点 + + def analyze(self): + for analyze_func in [self._pre_analyze, self._analyze, self._post_analyze]: + analyze_func() + if self._diff_nodes: + self._gen_analyze_info() + return + logger.info('Cannot find any anomaly node, no need to generate analyze file.') + + """ + 这里需要生成stack,但是直接用dict中自带就行,在op_items.NPU_Stack_Info中 + """ + def _resolve_input_path(self): + contents = os.listdir(self._input_path) + for path in contents: + if not path.startswith('rank'): + continue + rank_str = path.strip('rank') + if not rank_str: + rank = 0 + elif not rank_str.isdigit(): + continue + else: + rank = int(rank_str) + dump_path = os.path.join(self._input_path, path, NanAnalyseConst.DUMP_FILE) + construct_path = os.path.join(self._input_path, path, NanAnalyseConst.CONSTRUCT_FILE) + stack_path = os.path.join(self._input_path, path, NanAnalyseConst.STACK_FILE) + self._paths[rank] = RankPath(rank, dump_path, construct_path, stack_path) + + def _pre_analyze(self): + logger.info('Start searching anomaly node before communication.') + for path in self._paths.values(): + dump_data = self._cache.load_json(path.dump_path) + if not dump_data: + logger.warning(f'Rank {path.rank} has no dump data!') + continue + for op_name, op_data in dump_data.items(): + if is_communication_op(op_name): + self._first_comm_nodes[path.rank] = op_name + break + data_node = DataNode(op_name, path.rank, op_data) + if data_node.is_diff: + self._diff_nodes.append(data_node) + break + + def _analyze(self): + logger.info('Start searching anomaly node during communication.') + self._rank_comm_nodes_dict = {rank: self._analyze_comm_nodes(rank) for rank in self._paths} + self._connect_comm_nodes() + self._pruning() + self._search_first_diff() + + def _post_analyze(self): + logger.info('Start searching anomaly node after communication.') + for nodes in self._after_comm_anomalies.values(): + if nodes: + self._diff_nodes.append(nodes[0]) + + def _connect_comm_nodes(self): + searched_ranks = set() + for rank, nodes in list(self._rank_comm_nodes_dict.items())[:-1]: + searched_ranks.add(rank) + seen_nodes = set() + for cur_node in nodes.values(): + conn_info = cur_node.find_connected_nodes() + if not conn_info.get('ranks'): + conn_info['ranks'] = self._rank_comm_nodes_dict.keys() + if not self._find_connection(conn_info, cur_node, searched_ranks, seen_nodes): + logger.info(f'Cannot find connected communication node for "{cur_node.node_id}".') + + def _analyze_comm_nodes(self, rank): + path = self._paths[rank] + data = self._cache.load_json(path.dump_path) + communication_nodes = {} + if rank not in self._first_comm_nodes: # 此rank没有通信节点 + return communication_nodes + last_node_id = None # 记录上一个通信节点的node_id + compute_ops = [] # 记录两个通信节点之间的计算节点 + sub_layer = 0 # 记录两个通信算子之间异常计算节点的调用序数 + for op_name in dropwhile(lambda k: k != self._first_comm_nodes[rank], data): + node_id = f'{rank}.{op_name}' + op_data = data[op_name] + if is_communication_op(op_name): + comm_node = CommunicationNode(node_id, rank, DataNode(op_name, rank, op_data, sub_layer=sub_layer), + compute_ops=compute_ops) + if last_node_id: + communication_nodes.get(last_node_id).add_next(comm_node) + communication_nodes[node_id] = comm_node + last_node_id = node_id + compute_ops = [] + sub_layer = 0 + elif not is_ignore_op(op_name): + data_node = DataNode(op_name, rank, op_data, sub_layer=sub_layer) + if data_node.is_diff: + compute_ops.append(data_node) + sub_layer += 1 + if compute_ops: + self._after_comm_anomalies[rank] = compute_ops + return communication_nodes + + def _pruning(self): + deleted_node_id = [] + for nodes in self._rank_comm_nodes_dict.values(): + for node_id in list(nodes.keys()): + node = nodes[node_id] + if node.is_diff or node.compute_ops: + continue + deleted_node_id.append(node_id) + node.delete() + del nodes[node_id] + logger.debug(f'After pruning, following nodes are removed: [{", ".join(deleted_node_id)}]') + + def _search_first_diff(self): + nodes_queues = [] + for comm_nodes in self._rank_comm_nodes_dict.values(): + nodes_queues.append(sorted(list(comm_nodes.values()), key=lambda x: x.layer)) + seen_nodes = set() + + def get_next_node(node_list): + while node_list: + next_node = node_list.pop(0) + if next_node.node_id not in seen_nodes: + return next_node + return None + + def find_all_members(ori_node): + ids = get_relative_ids(ori_node) + id_queue = list(chain(*[get_relative_ids(self._get_node_by_id(n_id)).difference(ids) for n_id in ids])) + while id_queue: + new_id = id_queue.pop(0) + ids.add(new_id) + id_queue.extend(get_relative_ids(self._get_node_by_id(new_id)).difference(ids)) + return ids + + def get_relative_ids(ori_node): + if not ori_node: + return set() + return ({ori_node.node_id} | ori_node.link_nodes.keys() | ori_node.src_nodes.keys() | + ori_node.dst_nodes.keys()) + + while any(nodes_queues): + groups = [] + all_ids_in_groups = set() + for nodes in nodes_queues: + node = get_next_node(nodes) + if not node: + continue + if not groups or node.node_id in all_ids_in_groups: + new_group = find_all_members(node) + groups.append(new_group) + all_ids_in_groups.update(new_group) + for group in groups: + seen_nodes.update(group) + self._anomaly_nodes.extend(analyze_anomaly_in_group([self._get_node_by_id(n_id) for n_id in group])) + if self._anomaly_nodes: + self._anomaly_nodes = [min(self._anomaly_nodes, key=lambda x: (x.layer, x.sub_layer))] + return + + def _get_node_by_id(self, node_id): + splits = node_id.split(Const.SEP, 1) + if len(splits) < 2 or not splits[0].isdigit(): + logger.error(f'invalid node_id {node_id}') + raise RuntimeError(f'invalid node_id {node_id}') + rank = int(splits[0]) + return self._rank_comm_nodes_dict.get(rank, {}).get(node_id) \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/find_first/graph.py b/debug/accuracy_tools/msprobe/find_first/graph.py new file mode 100644 index 000000000..173fcb3df --- /dev/null +++ b/debug/accuracy_tools/msprobe/find_first/graph.py @@ -0,0 +1,177 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from msprobe.core.common.const import Const +from msprobe.core.common.log import logger +from msprobe.nan_analyze.utils import FileCache, RankPath, is_ignore_op, check_item_anomaly, NanAnalyseConst + + +@dataclass +class DataNode: + op_name: str + rank: int + inputs: list + input_args: list + input_kwargs: dict + outputs: dict + layer: int = 0 # 和communication_node的layer保持一致 + sub_layer: int = 0 # 调用顺序,越小表示越先调用 + + def __init__(self, op_name, rank, op_data, **kwargs): + self.op_name = op_name + self.rank = rank + self.inputs = op_data.get(Const.INPUT, []) + self.input_args = op_data.get(Const.INPUT_ARGS, []) + self.input_kwargs = op_data.get(Const.INPUT_KWARGS, {}) + self.outputs = op_data.get(Const.OUTPUT, {}) + self.sub_layer = kwargs.get('sub_layer', 0) + self.is_diff = op_data.get("is_same", False) + + @staticmethod + def find_complete_construct(construct_info, op_name): + construct = [op_name] + seen = set(op_name) + while True: + op_name = construct_info.get(op_name) + if not op_name or op_name in seen: + return construct + construct.insert(0, op_name) + seen.add(op_name) + + def find_stack(self, stack_info): + for item in stack_info.values(): + if len(item) >= 2 and self.op_name in item[0]: + return item[1] + return {} + + def gen_node_info(self, path: RankPath): + cache = FileCache() + construct = cache.load_json(path.construct_path) + stack = cache.load_json(path.stack_path) + if Const.FORWARD in self.op_name: + data_info_list = {Const.INPUT_ARGS: self.input_args, Const.INPUT_KWARGS: self.input_kwargs, + Const.OUTPUT: self.outputs} + else: + data_info_list = {Const.INPUT: self.inputs, Const.OUTPUT: self.outputs} + return {'op_name': self.op_name, + 'data_info': data_info_list, + 'construct_info': self.find_complete_construct(construct, self.op_name), + 'stack_info': self.find_stack(stack)} + + +class CommunicationNode: + def __init__(self, node_id, rank, data: DataNode, layer=0, **kwargs): + self.node_id = node_id + self.rank = rank + self.data = data + self.is_diff = data.get('is_same', False) + self.layer = layer + op_name_split = self.data.op_name.split(Const.SEP) + if len(op_name_split) < 4: + logger.error(f'invalid op_name: {self.data.op_name}') + raise RuntimeError(f'invalid op_name: {self.data.op_name}') + self.api = op_name_split[1] + self.call_cnt = op_name_split[2] + self.pre_node = kwargs.get('pre_node') + self.link_nodes = kwargs.get('link_nodes', {}) + self.dst_nodes = kwargs.get('dst_nodes', {}) + self.src_nodes = kwargs.get('src_nodes', {}) + self.next_nodes = kwargs.get('next_nodes', {}) + self.compute_ops = kwargs.get('compute_ops', []) + self.type = self._resolve_type() + self.connected = False + + def add_next(self, node): + self.next_nodes[node.node_id] = node + node.pre_node = self + node.layer = self.layer + 1 + node.data.layer = node.layer + + def add_link(self, node): + self.link_nodes[node.node_id] = node + node.link_nodes[self.node_id] = self + node.layer = self.layer + node.data.layer = node.layer + self.connected = True + node.connected = True + + def add_dst(self, node): + self.dst_nodes[node.node_id] = node + node.src_nodes[self.node_id] = self + node.layer = self.layer + node.data.layer = node.layer + self.connected = True + node.connected = True + + def delete(self): + for node in self.next_nodes.values(): + node.pre_node = None + for node in self.dst_nodes.values(): + node.src_nodes.pop(self.node_id) + for node in self.src_nodes.values(): + node.dst_nodes.pop(self.node_id) + for node in self.link_nodes.values(): + node.link_nodes.pop(self.node_id) + if self.pre_node: + self.pre_node.next_nodes.pop(self.node_id) + + def find_connected_nodes(self): + """ + 根据 api/类型/入参/调用次数 确定相连接的node的op_name + """ + tar_api = NanAnalyseConst.P2P_API_MAPPING.get(self.api, self.api) + ranks = set() + for dst in [NanAnalyseConst.DST, NanAnalyseConst.DST_GROUP]: + if dst in self.data.input_kwargs: + dst_value = self.data.input_kwargs.get(dst) + if dst_value: + ranks.add(dst_value.get('value')) + break + for src in [NanAnalyseConst.SRC, NanAnalyseConst.SRC_GROUP]: + if src in self.data.input_kwargs: + src_value = self.data.input_kwargs.get(src) + if src_value: + ranks.add(src_value.get('value')) + break + if not ranks: + for item in self.data.input_args: + if isinstance(item, dict) and item.get(Const.TYPE) == 'int': + ranks.add(item.get('value')) + group = self.data.input_kwargs.get('group') + if group: + ranks.update(group.get('group_ranks')) + return {'ranks': ranks, 'api': f'Distributed.{tar_api}', + 'type': NanAnalyseConst.OPPOSITE_DIR.get(self.type, NanAnalyseConst.LINK)} + + def _resolve_type(self): + for src in [NanAnalyseConst.SRC, NanAnalyseConst.SRC_GROUP]: + if src in self.data.input_kwargs and self.data.input_kwargs[src]: + if self.data.input_kwargs[src].get('value') == self.rank: + return NanAnalyseConst.SRC + else: + return NanAnalyseConst.DST + for dst in [NanAnalyseConst.DST, NanAnalyseConst.DST_GROUP]: + if dst in self.data.input_kwargs and self.data.input_kwargs[dst]: + if self.data.input_kwargs[dst].get('value') == self.rank: + return NanAnalyseConst.DST + else: + return NanAnalyseConst.SRC + if self.api in NanAnalyseConst.DIRECTED_API: + for item in self.data.input_args: + if item.get(Const.TYPE) == 'int': + node_type = NanAnalyseConst.DIRECTED_API[self.api] + return node_type if item.get('value') == self.rank else NanAnalyseConst.OPPOSITE_DIR[node_type] + return NanAnalyseConst.LINK diff --git a/debug/accuracy_tools/msprobe/find_first/utils.py b/debug/accuracy_tools/msprobe/find_first/utils.py new file mode 100644 index 000000000..5cb230e9c --- /dev/null +++ b/debug/accuracy_tools/msprobe/find_first/utils.py @@ -0,0 +1,24 @@ +def analyze_anomaly_in_group(nodes_group): + anomaly_nodes = [] + + def get_compute_ops_from_comm_nodes(comm_nodes): + for comm_node in comm_nodes: + for op_node in comm_node.compute_ops: + op_node.layer = comm_node.layer + anomaly_nodes.append(op_node) + + def get_comm_ops(comm_nodes): + for node in comm_nodes: + node.data.layer = node.layer + anomaly_nodes.append(node.data) + + # 先看src或link中input是否有异常 + src_list = list(filter(lambda node: node.type in [NanAnalyseConst.SRC, NanAnalyseConst.LINK], nodes_group)) + input_anomaly_nodes = list(filter(lambda node: node.is_diff, src_list)) + # 如果有异常回溯计算节点找到异常来源 + # 使用cpu模拟节点进行计算,查看结果是否有问题。需要对所有计算节点录入/映射,暂不实现。 + get_compute_ops_from_comm_nodes(input_anomaly_nodes) + # 筛选入参没问题但出参有问题的通信节点 + output_anomaly_nodes = list(filter(lambda node: node.data.is_diff, nodes_group)) + get_comm_ops(output_anomaly_nodes) + return anomaly_nodes \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/mindspore/monitor/optimizer_collect.py b/debug/accuracy_tools/msprobe/mindspore/monitor/optimizer_collect.py index 7efbb4590..90bb0dd39 100644 --- a/debug/accuracy_tools/msprobe/mindspore/monitor/optimizer_collect.py +++ b/debug/accuracy_tools/msprobe/mindspore/monitor/optimizer_collect.py @@ -1,17 +1,3 @@ -# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. from abc import abstractmethod from mindspore import mint, ops @@ -105,9 +91,6 @@ class OptimizerMon(object): else: logger.warning(f"step of {name} is None, maybe something wrong happened.") continue - if exp_avg is None or exp_avg_sq is None: - logger.warning(f"exp_avg or exp_avg_sq of {name} is None, skip calculation.") - continue exp_avg_hat = exp_avg / (1 - self.optim.defaults['betas'][0] ** step) exp_avg_sq_hat = exp_avg_sq / (1 - self.optim.defaults['betas'][1] ** step) update_dict[name] = exp_avg_hat / (mint.sqrt(exp_avg_sq_hat) + self.optim.defaults['eps']) -- Gitee From d561d2cd4a725b12887336a60704cc1601a25f72 Mon Sep 17 00:00:00 2001 From: TAJh Date: Wed, 18 Jun 2025 09:43:52 +0800 Subject: [PATCH 02/18] version 2.0 --- .../msprobe/find_first/analyzer.py | 4 +-- .../msprobe/find_first/utils.py | 35 ++++++++++++++----- 2 files changed, 28 insertions(+), 11 deletions(-) diff --git a/debug/accuracy_tools/msprobe/find_first/analyzer.py b/debug/accuracy_tools/msprobe/find_first/analyzer.py index a4c1ab691..56934dde3 100644 --- a/debug/accuracy_tools/msprobe/find_first/analyzer.py +++ b/debug/accuracy_tools/msprobe/find_first/analyzer.py @@ -6,12 +6,12 @@ from itertools import dropwhile, chain from msprobe.core.common.file_utils import check_file_or_directory_path, save_json, make_dir from msprobe.core.common.log import logger from msprobe.core.common.const import Const -from msprobe.nan_analyze.utils import (RankPath, FileCache, is_communication_op, is_ignore_op, NanAnalyseConst, +from msprobe.find_first.utils import (RankPath, FileCache, is_communication_op, is_ignore_op, NanAnalyseConst, analyze_anomaly_in_group) from msprobe.find_first.graph import DataNode, CommunicationNode -class NanAnalyzer: +class DiffAnalyzer: def __init__(self, input_path, output_path): self._input_path = input_path self._output_path = output_path diff --git a/debug/accuracy_tools/msprobe/find_first/utils.py b/debug/accuracy_tools/msprobe/find_first/utils.py index 5cb230e9c..59135b08f 100644 --- a/debug/accuracy_tools/msprobe/find_first/utils.py +++ b/debug/accuracy_tools/msprobe/find_first/utils.py @@ -1,24 +1,41 @@ -def analyze_anomaly_in_group(nodes_group): - anomaly_nodes = [] +@dataclass +class RankPath: + rank: int + dump_path: str + construct_path: str + stack_path: str + + def __init__(self, rank, dump_path, construct_path, stack_path): + self.rank = rank + check_file_or_directory_path(dump_path) + self.dump_path = dump_path + check_file_or_directory_path(construct_path) + self.construct_path = construct_path + check_file_or_directory_path(stack_path) + self.stack_path = stack_path + + +def analyze_diff_in_group(nodes_group): + diff_nodes = [] def get_compute_ops_from_comm_nodes(comm_nodes): for comm_node in comm_nodes: for op_node in comm_node.compute_ops: op_node.layer = comm_node.layer - anomaly_nodes.append(op_node) + diff_nodes.append(op_node) def get_comm_ops(comm_nodes): for node in comm_nodes: node.data.layer = node.layer - anomaly_nodes.append(node.data) + diff_nodes.append(node.data) # 先看src或link中input是否有异常 src_list = list(filter(lambda node: node.type in [NanAnalyseConst.SRC, NanAnalyseConst.LINK], nodes_group)) - input_anomaly_nodes = list(filter(lambda node: node.is_diff, src_list)) + input_diff_nodes = list(filter(lambda node: node.is_diff, src_list)) # 如果有异常回溯计算节点找到异常来源 # 使用cpu模拟节点进行计算,查看结果是否有问题。需要对所有计算节点录入/映射,暂不实现。 - get_compute_ops_from_comm_nodes(input_anomaly_nodes) + get_compute_ops_from_comm_nodes(input_diff_nodes) # 筛选入参没问题但出参有问题的通信节点 - output_anomaly_nodes = list(filter(lambda node: node.data.is_diff, nodes_group)) - get_comm_ops(output_anomaly_nodes) - return anomaly_nodes \ No newline at end of file + output_diff_nodes = list(filter(lambda node: node.data.is_diff, nodes_group)) + get_comm_ops(output_diff_nodes) + return diff_nodes \ No newline at end of file -- Gitee From 271feb09f89a507892cf6e0656fb7025439c1dca Mon Sep 17 00:00:00 2001 From: TAJh Date: Wed, 18 Jun 2025 14:08:03 +0800 Subject: [PATCH 03/18] version 3.0 --- .../msprobe/find_first/analyzer.py | 43 ++++-- .../msprobe/find_first/utils.py | 130 +++++++++++++++++- debug/accuracy_tools/msprobe/msprobe.py | 6 +- 3 files changed, 168 insertions(+), 11 deletions(-) diff --git a/debug/accuracy_tools/msprobe/find_first/analyzer.py b/debug/accuracy_tools/msprobe/find_first/analyzer.py index 56934dde3..ffed3134e 100644 --- a/debug/accuracy_tools/msprobe/find_first/analyzer.py +++ b/debug/accuracy_tools/msprobe/find_first/analyzer.py @@ -6,8 +6,8 @@ from itertools import dropwhile, chain from msprobe.core.common.file_utils import check_file_or_directory_path, save_json, make_dir from msprobe.core.common.log import logger from msprobe.core.common.const import Const -from msprobe.find_first.utils import (RankPath, FileCache, is_communication_op, is_ignore_op, NanAnalyseConst, - analyze_anomaly_in_group) +from msprobe.find_first.utils import (RankPath, FileCache, is_communication_op, is_ignore_op, DiffAnalyseConst, + analyze_diff_in_group) from msprobe.find_first.graph import DataNode, CommunicationNode @@ -46,9 +46,9 @@ class DiffAnalyzer: continue else: rank = int(rank_str) - dump_path = os.path.join(self._input_path, path, NanAnalyseConst.DUMP_FILE) - construct_path = os.path.join(self._input_path, path, NanAnalyseConst.CONSTRUCT_FILE) - stack_path = os.path.join(self._input_path, path, NanAnalyseConst.STACK_FILE) + dump_path = os.path.join(self._input_path, path, DiffAnalyseConst.DUMP_FILE) + construct_path = os.path.join(self._input_path, path, DiffAnalyseConst.CONSTRUCT_FILE) + stack_path = os.path.join(self._input_path, path, DiffAnalyseConst.STACK_FILE) self._paths[rank] = RankPath(rank, dump_path, construct_path, stack_path) def _pre_analyze(self): @@ -175,9 +175,9 @@ class DiffAnalyzer: all_ids_in_groups.update(new_group) for group in groups: seen_nodes.update(group) - self._anomaly_nodes.extend(analyze_anomaly_in_group([self._get_node_by_id(n_id) for n_id in group])) - if self._anomaly_nodes: - self._anomaly_nodes = [min(self._anomaly_nodes, key=lambda x: (x.layer, x.sub_layer))] + self._diff_nodes.extend(analyze_diff_in_group([self._get_node_by_id(n_id) for n_id in group])) + if self._diff_nodes: + self._diff_nodes = [min(self._diff_nodes, key=lambda x: (x.layer, x.sub_layer))] return def _get_node_by_id(self, node_id): @@ -186,4 +186,29 @@ class DiffAnalyzer: logger.error(f'invalid node_id {node_id}') raise RuntimeError(f'invalid node_id {node_id}') rank = int(splits[0]) - return self._rank_comm_nodes_dict.get(rank, {}).get(node_id) \ No newline at end of file + return self._rank_comm_nodes_dict.get(rank, {}).get(node_id) + + def _gen_analyze_info(self): + if not os.path.exists(self._output_path): + make_dir(self._output_path) + file_name = f'anomaly_analyze_{time.time_ns()}.json' + result_file = os.path.join(self._output_path, file_name) + result_content = defaultdict(list) + for node in self._diff_nodes: + result_content[f'rank_{node.rank}'].append(node.gen_node_info(self._paths[node.rank])) + save_json(result_file, result_content, 2) + logger.info(f"The analyze result is saved in: {result_file}") + + +def _diff_analyze_parser(parser): + parser.add_argument("-i", "--input_path", dest="input_path", default="", type=str, + help=" The dump file path, over step level. eg: \"xxx/step_0/\".", + required=True) + parser.add_argument("-o", "--output_path", dest="output_path", default="./output", type=str, + help=" The diff analyze result output file path.", + required=False) + + +def _run_diff_analyze(args): + check_file_or_directory_path(args.input_path, True) + DiffAnalyzer(args.input_path, args.output_path).analyze() \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/find_first/utils.py b/debug/accuracy_tools/msprobe/find_first/utils.py index 59135b08f..559a8f363 100644 --- a/debug/accuracy_tools/msprobe/find_first/utils.py +++ b/debug/accuracy_tools/msprobe/find_first/utils.py @@ -15,6 +15,134 @@ class RankPath: self.stack_path = stack_path +class FileCache: + """ + lazy load file + """ + _instance = None + + def __new__(cls, *args, **kwargs): + if not cls._instance: + cls._instance = super().__new__(cls, *args, **kwargs) + return cls._instance + + def __init__(self): + self._max_memory_usage = psutil.virtual_memory().available / 4 # 最大占用当前可用内存空间的1/4 + self._cache = OrderedDict() + self._access_cnt = {} + self._access_time = {} + self._size = {} + + @staticmethod + def _sizeof(obj): + seen = set() + objs = [obj] + size = 0 + while objs: + obj = objs.pop() + obj_id = id(obj) + if obj_id in seen: + continue + seen.add(obj_id) + size += sys.getsizeof(obj) + if isinstance(obj, dict): + objs.extend(obj.keys()) + objs.extend(obj.values()) + elif isinstance(obj, (list, tuple, set, frozenset)): + objs.extend(obj) + return size + + def load_json(self, json_path): + if json_path in self._cache: + self._access_cnt[json_path] += 1 + self._access_time[json_path] = time.monotonic() + self._cache.move_to_end(json_path) + return self._cache[json_path] + self._cleanup() + return self._load(json_path) + + def _load(self, json_path): + data = load_json(json_path) + self._add_to_cache(json_path, data) + return data + + def _add_to_cache(self, key, data): + if key in self._cache: + self._cache.move_to_end(key) + else: + self._cache[key] = data + self._access_cnt[key] = 0 + self._access_time[key] = time.monotonic() + self._size[key] = self._sizeof(data) + + def _calc_cache_size(self): + return sys.getsizeof(self._cache) + sum(self._size.values()) + + def _cleanup(self): + while self._calc_cache_size() > self._max_memory_usage and self._cache: + least_frequent_key = min(self._access_cnt.keys(), key=lambda k: self._access_cnt[k]) + least_recent_key = min(self._access_time.keys(), key=lambda k: self._access_time[k]) + largest_key = max(self._cache.keys(), key=lambda k: self._size[k]) + key_to_rm = min([least_frequent_key, least_recent_key, largest_key], + key=lambda k: (self._access_cnt[k], self._access_time[k], -self._size[k])) + del self._cache[key_to_rm] + del self._access_cnt[key_to_rm] + del self._access_time[key_to_rm] + del self._size[key_to_rm] + + +def is_communication_op(op_name): + # 定义通信算子的关键字,覆盖各种通信操作,如all_reduce, send, broadcast等 + # 从wrap文件中读取,先硬编码在文件中 + return (op_name.startswith('Distributed.') and + any(keyword in op_name for keyword in DiffAnalyseConst.COMMUNICATION_KEYWORDS)) + + +def is_ignore_op(op_name): + ignore_keywords = [ + 'Torch.empty', + 'Torch.fill' + ] + return any(keyword in op_name for keyword in ignore_keywords) + + +class DiffAnalyseConst: + COMMUNICATION_KEYWORDS = { + 'send', # send 算子 + 'recv', # recv 算子 + 'broadcast', # broadcast 算子 + 'all_reduce', # all_reduce 算子 + 'reduce', # reduce 算子 + 'all_gather', # all_gather 算子 + 'gather', # gather 算子 + 'isend', # isend 算子 + 'irecv', # irecv 算子 + 'scatter', # scatter 算子 + 'reduce_scatter', # reduce_scatter 算子 + '_reduce_scatter_base', # _reduce_scatter_base 算子 + '_all_gather_base', # _all_gather_base 算子 + 'all_to_all_single', # all_to_all_single 算子 + 'all_to_all', # all_to_all 算子 + 'all_gather_into_tensor', # all_gather_into_tensor 算子 + 'reduce_scatter_tensor', # reduce_scatter_tensor 算子 + 'send_object_list', # send_object_list 算子 + 'recv_object_list' # recv_object_list 算子 + } + P2P_API_MAPPING = {'send': 'recv', 'recv': 'send', 'isend': 'irecv', 'irecv': 'isend', + 'send_object_list': 'recv_object_list', 'recv_object_list': 'send_object_list'} + SRC = 'src' + DST = 'dst' + SRC_GROUP = 'src_group' + DST_GROUP = 'dst_group' + LINK = 'link' + DIRECTED_API = {'send': DST, 'recv': SRC, 'isend': DST, 'irecv': SRC, 'broadcast': SRC, 'scatter': SRC, + 'gather': DST, 'send_object_list': DST, 'recv_object_list': SRC} + OPPOSITE_DIR = {SRC: DST, DST: SRC} + DUMP_FILE = "dump.json" + CONSTRUCT_FILE = "construct.json" + STACK_FILE = "stack.json" + + def analyze_diff_in_group(nodes_group): diff_nodes = [] @@ -30,7 +158,7 @@ def analyze_diff_in_group(nodes_group): diff_nodes.append(node.data) # 先看src或link中input是否有异常 - src_list = list(filter(lambda node: node.type in [NanAnalyseConst.SRC, NanAnalyseConst.LINK], nodes_group)) + src_list = list(filter(lambda node: node.type in [DiffAnalyseConst.SRC, DiffAnalyseConst.LINK], nodes_group)) input_diff_nodes = list(filter(lambda node: node.is_diff, src_list)) # 如果有异常回溯计算节点找到异常来源 # 使用cpu模拟节点进行计算,查看结果是否有问题。需要对所有计算节点录入/映射,暂不实现。 diff --git a/debug/accuracy_tools/msprobe/msprobe.py b/debug/accuracy_tools/msprobe/msprobe.py index 221fb1da8..a93330e8f 100644 --- a/debug/accuracy_tools/msprobe/msprobe.py +++ b/debug/accuracy_tools/msprobe/msprobe.py @@ -77,7 +77,7 @@ def main(): from msprobe.pytorch.api_accuracy_checker.generate_op_script.op_generator import _op_generator_parser, \ _run_operator_generate_commond from msprobe.nan_analyze.analyzer import _nan_analyze_parser, _run_nan_analyze - + from msprobe.find_first.analyzer import _diff_analyze_parser, _run_diff_analyze _run_ut_parser(run_ut_cmd_parser) _run_ut_parser(multi_run_ut_cmd_parser) multi_run_ut_cmd_parser.add_argument('-n', '--num_splits', type=int, choices=range(1, 65), default=8, @@ -87,6 +87,8 @@ def main(): _pt_graph_service_parser(graph_service_cmd_parser) _op_generator_parser(op_generate_cmd_parser) _nan_analyze_parser(nan_analyze_parser) + _diff_analyze_parser(diff_analyze_parser) + elif framework_args.framework == Const.MS_FRAMEWORK: from msprobe.mindspore.api_accuracy_checker.cmd_parser import add_api_accuracy_checker_argument from msprobe.visualization.graph_service import _ms_graph_service_parser, _ms_graph_service_command @@ -133,6 +135,8 @@ def main(): _run_config_checking_command(args) elif sys.argv[3] == "nan_analyze": _run_nan_analyze(args) + elif sys.argv[3] == "diff_analyze": + _run_diff_analyze(args) else: if not is_module_available(Const.MS_FRAMEWORK): logger.error("MindSpore does not exist, please install MindSpore library") -- Gitee From c1de083fdfd14189a5e827c885116e5c42d85d20 Mon Sep 17 00:00:00 2001 From: TAJh Date: Wed, 18 Jun 2025 14:15:27 +0800 Subject: [PATCH 04/18] version 4.0 --- debug/accuracy_tools/msprobe/find_first/utils.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/debug/accuracy_tools/msprobe/find_first/utils.py b/debug/accuracy_tools/msprobe/find_first/utils.py index 559a8f363..cef9d7316 100644 --- a/debug/accuracy_tools/msprobe/find_first/utils.py +++ b/debug/accuracy_tools/msprobe/find_first/utils.py @@ -1,3 +1,11 @@ +from collections import OrderedDict +from dataclasses import dataclass +import sys +import time +import psutil + +from msprobe.core.common.file_utils import check_file_or_directory_path, load_json + @dataclass class RankPath: rank: int -- Gitee From a81c3e89ff74586bff3a30b87d17db8fca2aa73b Mon Sep 17 00:00:00 2001 From: TAJh Date: Wed, 18 Jun 2025 14:16:54 +0800 Subject: [PATCH 05/18] version 5.0 --- debug/accuracy_tools/msprobe/msprobe.py | 1 + 1 file changed, 1 insertion(+) diff --git a/debug/accuracy_tools/msprobe/msprobe.py b/debug/accuracy_tools/msprobe/msprobe.py index a93330e8f..e7780ef5f 100644 --- a/debug/accuracy_tools/msprobe/msprobe.py +++ b/debug/accuracy_tools/msprobe/msprobe.py @@ -55,6 +55,7 @@ def main(): merge_result_parser = subparsers.add_parser('merge_result') config_checking_parser = subparsers.add_parser('config_check') nan_analyze_parser = subparsers.add_parser('nan_analyze') + diff_analyze_parser = subparsers.add_parser('diff_analyze') _config_checking_parser(config_checking_parser) _compare_parser(compare_cmd_parser) _merge_result_parser(merge_result_parser) -- Gitee From 1203dd489f044b46b0c206229c7967975a44e414 Mon Sep 17 00:00:00 2001 From: TAJh Date: Wed, 18 Jun 2025 14:18:19 +0800 Subject: [PATCH 06/18] version 6.0 --- debug/accuracy_tools/msprobe/find_first/analyzer.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/debug/accuracy_tools/msprobe/find_first/analyzer.py b/debug/accuracy_tools/msprobe/find_first/analyzer.py index ffed3134e..2b8cf89e9 100644 --- a/debug/accuracy_tools/msprobe/find_first/analyzer.py +++ b/debug/accuracy_tools/msprobe/find_first/analyzer.py @@ -29,7 +29,7 @@ class DiffAnalyzer: if self._diff_nodes: self._gen_analyze_info() return - logger.info('Cannot find any anomaly node, no need to generate analyze file.') + logger.info('Cannot find any diff node, no need to generate analyze file.') """ 这里需要生成stack,但是直接用dict中自带就行,在op_items.NPU_Stack_Info中 @@ -52,7 +52,7 @@ class DiffAnalyzer: self._paths[rank] = RankPath(rank, dump_path, construct_path, stack_path) def _pre_analyze(self): - logger.info('Start searching anomaly node before communication.') + logger.info('Start searching diff node before communication.') for path in self._paths.values(): dump_data = self._cache.load_json(path.dump_path) if not dump_data: @@ -68,14 +68,14 @@ class DiffAnalyzer: break def _analyze(self): - logger.info('Start searching anomaly node during communication.') + logger.info('Start searching diff node during communication.') self._rank_comm_nodes_dict = {rank: self._analyze_comm_nodes(rank) for rank in self._paths} self._connect_comm_nodes() self._pruning() self._search_first_diff() def _post_analyze(self): - logger.info('Start searching anomaly node after communication.') + logger.info('Start searching diff node after communication.') for nodes in self._after_comm_anomalies.values(): if nodes: self._diff_nodes.append(nodes[0]) @@ -191,7 +191,7 @@ class DiffAnalyzer: def _gen_analyze_info(self): if not os.path.exists(self._output_path): make_dir(self._output_path) - file_name = f'anomaly_analyze_{time.time_ns()}.json' + file_name = f'diff_analyze_{time.time_ns()}.json' result_file = os.path.join(self._output_path, file_name) result_content = defaultdict(list) for node in self._diff_nodes: -- Gitee From 5b91b748b4d8f3b8e5675b24465ee79d113dcf99 Mon Sep 17 00:00:00 2001 From: TAJh Date: Wed, 18 Jun 2025 14:28:55 +0800 Subject: [PATCH 07/18] version 7.0 --- debug/accuracy_tools/msprobe/find_first/graph.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/debug/accuracy_tools/msprobe/find_first/graph.py b/debug/accuracy_tools/msprobe/find_first/graph.py index 173fcb3df..cbd9476c7 100644 --- a/debug/accuracy_tools/msprobe/find_first/graph.py +++ b/debug/accuracy_tools/msprobe/find_first/graph.py @@ -38,7 +38,7 @@ class DataNode: self.input_kwargs = op_data.get(Const.INPUT_KWARGS, {}) self.outputs = op_data.get(Const.OUTPUT, {}) self.sub_layer = kwargs.get('sub_layer', 0) - self.is_diff = op_data.get("is_same", False) + self.is_diff = not op_data.get("is_same", True) @staticmethod def find_complete_construct(construct_info, op_name): -- Gitee From 26076c35d589bb177561cb50faba9cb0245eba1a Mon Sep 17 00:00:00 2001 From: TAJh Date: Wed, 18 Jun 2025 14:35:39 +0800 Subject: [PATCH 08/18] version 8.0 --- debug/accuracy_tools/msprobe/find_first/graph.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/debug/accuracy_tools/msprobe/find_first/graph.py b/debug/accuracy_tools/msprobe/find_first/graph.py index cbd9476c7..1ab211372 100644 --- a/debug/accuracy_tools/msprobe/find_first/graph.py +++ b/debug/accuracy_tools/msprobe/find_first/graph.py @@ -77,7 +77,7 @@ class CommunicationNode: self.node_id = node_id self.rank = rank self.data = data - self.is_diff = data.get('is_same', False) + self.is_diff = data.is_diff self.layer = layer op_name_split = self.data.op_name.split(Const.SEP) if len(op_name_split) < 4: -- Gitee From 321c0e55c848cbec22be9e62d41ce4a5ba64b05d Mon Sep 17 00:00:00 2001 From: TAJh Date: Wed, 18 Jun 2025 14:41:47 +0800 Subject: [PATCH 09/18] version 9.0 --- debug/accuracy_tools/msprobe/find_first/analyzer.py | 4 +--- debug/accuracy_tools/msprobe/find_first/graph.py | 7 +++---- debug/accuracy_tools/msprobe/find_first/utils.py | 7 +------ 3 files changed, 5 insertions(+), 13 deletions(-) diff --git a/debug/accuracy_tools/msprobe/find_first/analyzer.py b/debug/accuracy_tools/msprobe/find_first/analyzer.py index 2b8cf89e9..9ea5f48e7 100644 --- a/debug/accuracy_tools/msprobe/find_first/analyzer.py +++ b/debug/accuracy_tools/msprobe/find_first/analyzer.py @@ -47,9 +47,7 @@ class DiffAnalyzer: else: rank = int(rank_str) dump_path = os.path.join(self._input_path, path, DiffAnalyseConst.DUMP_FILE) - construct_path = os.path.join(self._input_path, path, DiffAnalyseConst.CONSTRUCT_FILE) - stack_path = os.path.join(self._input_path, path, DiffAnalyseConst.STACK_FILE) - self._paths[rank] = RankPath(rank, dump_path, construct_path, stack_path) + self._paths[rank] = RankPath(rank, dump_path) def _pre_analyze(self): logger.info('Start searching diff node before communication.') diff --git a/debug/accuracy_tools/msprobe/find_first/graph.py b/debug/accuracy_tools/msprobe/find_first/graph.py index 1ab211372..dd7419e02 100644 --- a/debug/accuracy_tools/msprobe/find_first/graph.py +++ b/debug/accuracy_tools/msprobe/find_first/graph.py @@ -39,6 +39,7 @@ class DataNode: self.outputs = op_data.get(Const.OUTPUT, {}) self.sub_layer = kwargs.get('sub_layer', 0) self.is_diff = not op_data.get("is_same", True) + self.stack = op_data.get("op_items")[0].get("NPU_Stack_Info") @staticmethod def find_complete_construct(construct_info, op_name): @@ -59,8 +60,6 @@ class DataNode: def gen_node_info(self, path: RankPath): cache = FileCache() - construct = cache.load_json(path.construct_path) - stack = cache.load_json(path.stack_path) if Const.FORWARD in self.op_name: data_info_list = {Const.INPUT_ARGS: self.input_args, Const.INPUT_KWARGS: self.input_kwargs, Const.OUTPUT: self.outputs} @@ -68,8 +67,8 @@ class DataNode: data_info_list = {Const.INPUT: self.inputs, Const.OUTPUT: self.outputs} return {'op_name': self.op_name, 'data_info': data_info_list, - 'construct_info': self.find_complete_construct(construct, self.op_name), - 'stack_info': self.find_stack(stack)} + 'construct_info': "None", + 'stack_info': self.stack} class CommunicationNode: diff --git a/debug/accuracy_tools/msprobe/find_first/utils.py b/debug/accuracy_tools/msprobe/find_first/utils.py index cef9d7316..0e8c30760 100644 --- a/debug/accuracy_tools/msprobe/find_first/utils.py +++ b/debug/accuracy_tools/msprobe/find_first/utils.py @@ -13,15 +13,10 @@ class RankPath: construct_path: str stack_path: str - def __init__(self, rank, dump_path, construct_path, stack_path): + def __init__(self, rank, dump_path): self.rank = rank check_file_or_directory_path(dump_path) self.dump_path = dump_path - check_file_or_directory_path(construct_path) - self.construct_path = construct_path - check_file_or_directory_path(stack_path) - self.stack_path = stack_path - class FileCache: """ -- Gitee From a11176e9d9ca160c17344640238fdc639cbcfc14 Mon Sep 17 00:00:00 2001 From: TAJh Date: Wed, 18 Jun 2025 15:30:47 +0800 Subject: [PATCH 10/18] version 10.0 --- .../msprobe/find_first/analyzer.py | 29 +++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/debug/accuracy_tools/msprobe/find_first/analyzer.py b/debug/accuracy_tools/msprobe/find_first/analyzer.py index 9ea5f48e7..fa75e3242 100644 --- a/debug/accuracy_tools/msprobe/find_first/analyzer.py +++ b/debug/accuracy_tools/msprobe/find_first/analyzer.py @@ -90,6 +90,35 @@ class DiffAnalyzer: if not self._find_connection(conn_info, cur_node, searched_ranks, seen_nodes): logger.info(f'Cannot find connected communication node for "{cur_node.node_id}".') + def _find_connection(self, conn_info, cur_node, searched_ranks, seen_nodes): + def connect(): + seen_nodes.add(search_node.node_id) + if search_node.type == DiffAnalyseConst.DST: + cur_node.add_dst(search_node) + elif search_node.type == DiffAnalyseConst.SRC: + search_node.layer = cur_node.layer + search_node.add_dst(cur_node) + else: + cur_node.add_link(search_node) + + found = cur_node.connected + for connected_rank in conn_info['ranks']: + if connected_rank in searched_ranks: + continue + tar_id_prefix = f'{connected_rank}.{conn_info["api"]}' + for search_id, search_node in self._rank_comm_nodes_dict[connected_rank].items(): + if search_id in seen_nodes: + continue + if not (search_id.startswith(tar_id_prefix) and search_node.type == conn_info.get('type')): + continue + search_conn_ranks = search_node.find_connected_nodes().get('ranks') + if ((not search_conn_ranks and search_node.api not in DiffAnalyseConst.DIRECTED_API) or + cur_node.rank in search_conn_ranks): # 有些无向通信算子没有填ProcessGroup,默认连接所有rank + connect() + found = True + break + return found + def _analyze_comm_nodes(self, rank): path = self._paths[rank] data = self._cache.load_json(path.dump_path) -- Gitee From be4dea6fc258249941d1e8d04863a9e33cc23b9e Mon Sep 17 00:00:00 2001 From: TAJh Date: Thu, 19 Jun 2025 16:48:34 +0800 Subject: [PATCH 11/18] version 11.0 --- .../msprobe/find_first/graph.py | 32 +++++++++---------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/debug/accuracy_tools/msprobe/find_first/graph.py b/debug/accuracy_tools/msprobe/find_first/graph.py index dd7419e02..080c5b1ab 100644 --- a/debug/accuracy_tools/msprobe/find_first/graph.py +++ b/debug/accuracy_tools/msprobe/find_first/graph.py @@ -16,7 +16,7 @@ from dataclasses import dataclass from msprobe.core.common.const import Const from msprobe.core.common.log import logger -from msprobe.nan_analyze.utils import FileCache, RankPath, is_ignore_op, check_item_anomaly, NanAnalyseConst +from msprobe.find_first.utils import FileCache, RankPath, is_ignore_op, check_item_anomaly, DiffAnalyseConst @dataclass @@ -131,15 +131,15 @@ class CommunicationNode: """ 根据 api/类型/入参/调用次数 确定相连接的node的op_name """ - tar_api = NanAnalyseConst.P2P_API_MAPPING.get(self.api, self.api) + tar_api = DiffAnalyseConst.P2P_API_MAPPING.get(self.api, self.api) ranks = set() - for dst in [NanAnalyseConst.DST, NanAnalyseConst.DST_GROUP]: - if dst in self.data.input_kwargs: + for dst in [DiffAnalyseConst.DST, DiffAnalyseConst.DST_GROUP]: + if dst in self.data.op_name: dst_value = self.data.input_kwargs.get(dst) if dst_value: ranks.add(dst_value.get('value')) break - for src in [NanAnalyseConst.SRC, NanAnalyseConst.SRC_GROUP]: + for src in [DiffAnalyseConst.SRC, DiffAnalyseConst.SRC_GROUP]: if src in self.data.input_kwargs: src_value = self.data.input_kwargs.get(src) if src_value: @@ -153,24 +153,24 @@ class CommunicationNode: if group: ranks.update(group.get('group_ranks')) return {'ranks': ranks, 'api': f'Distributed.{tar_api}', - 'type': NanAnalyseConst.OPPOSITE_DIR.get(self.type, NanAnalyseConst.LINK)} + 'type': DiffAnalyseConst.OPPOSITE_DIR.get(self.type, DiffAnalyseConst.LINK)} def _resolve_type(self): - for src in [NanAnalyseConst.SRC, NanAnalyseConst.SRC_GROUP]: + for src in [DiffAnalyseConst.SRC, DiffAnalyseConst.SRC_GROUP]: if src in self.data.input_kwargs and self.data.input_kwargs[src]: if self.data.input_kwargs[src].get('value') == self.rank: - return NanAnalyseConst.SRC + return DiffAnalyseConst.SRC else: - return NanAnalyseConst.DST - for dst in [NanAnalyseConst.DST, NanAnalyseConst.DST_GROUP]: + return DiffAnalyseConst.DST + for dst in [DiffAnalyseConst.DST, DiffAnalyseConst.DST_GROUP]: if dst in self.data.input_kwargs and self.data.input_kwargs[dst]: if self.data.input_kwargs[dst].get('value') == self.rank: - return NanAnalyseConst.DST + return DiffAnalyseConst.DST else: - return NanAnalyseConst.SRC - if self.api in NanAnalyseConst.DIRECTED_API: + return DiffAnalyseConst.SRC + if self.api in DiffAnalyseConst.DIRECTED_API: for item in self.data.input_args: if item.get(Const.TYPE) == 'int': - node_type = NanAnalyseConst.DIRECTED_API[self.api] - return node_type if item.get('value') == self.rank else NanAnalyseConst.OPPOSITE_DIR[node_type] - return NanAnalyseConst.LINK + node_type = DiffAnalyseConst.DIRECTED_API[self.api] + return node_type if item.get('value') == self.rank else DiffAnalyseConst.OPPOSITE_DIR[node_type] + return DiffAnalyseConst.LINK -- Gitee From d51d9807e0ab1021bb54ae029700e8c7e924ca76 Mon Sep 17 00:00:00 2001 From: TAJh Date: Sat, 21 Jun 2025 15:24:22 +0800 Subject: [PATCH 12/18] version 12.0 --- .../msprobe/find_first/graph.py | 108 +++++++----------- 1 file changed, 42 insertions(+), 66 deletions(-) diff --git a/debug/accuracy_tools/msprobe/find_first/graph.py b/debug/accuracy_tools/msprobe/find_first/graph.py index 080c5b1ab..d20b143b4 100644 --- a/debug/accuracy_tools/msprobe/find_first/graph.py +++ b/debug/accuracy_tools/msprobe/find_first/graph.py @@ -23,52 +23,48 @@ from msprobe.find_first.utils import FileCache, RankPath, is_ignore_op, check_it class DataNode: op_name: str rank: int - inputs: list - input_args: list - input_kwargs: dict + inputs: dict outputs: dict + op_data: list + is_same: bool layer: int = 0 # 和communication_node的layer保持一致 sub_layer: int = 0 # 调用顺序,越小表示越先调用 def __init__(self, op_name, rank, op_data, **kwargs): self.op_name = op_name self.rank = rank - self.inputs = op_data.get(Const.INPUT, []) - self.input_args = op_data.get(Const.INPUT_ARGS, []) - self.input_kwargs = op_data.get(Const.INPUT_KWARGS, {}) - self.outputs = op_data.get(Const.OUTPUT, {}) + self.stack = None + self.parse_data(op_data) self.sub_layer = kwargs.get('sub_layer', 0) - self.is_diff = not op_data.get("is_same", True) - self.stack = op_data.get("op_items")[0].get("NPU_Stack_Info") - - @staticmethod - def find_complete_construct(construct_info, op_name): - construct = [op_name] - seen = set(op_name) - while True: - op_name = construct_info.get(op_name) - if not op_name or op_name in seen: - return construct - construct.insert(0, op_name) - seen.add(op_name) - - def find_stack(self, stack_info): - for item in stack_info.values(): + self.is_diff = False + + def find_stack(self): + for item in self.stack.values(): if len(item) >= 2 and self.op_name in item[0]: return item[1] return {} + def parse_data(self, op_data): + self.is_diff = not op_data.get("is_same", True) + self.op_data = op_data.get("op_items") # 这里拿到的是比对column,是一个list,有若干行 + for cmp_data in self.op_data: + name = cmp_data.get("NPU_Name") + metrics = {"NPU max": cmp_data.get("NPU max"), + "NPU min": cmp_data.get("NPU min"), + "NPU mean": cmp_data.get("NPU mean")} + if cmp_data.get("NPU_Stack_Info") != "N/A" and not self.stack: + self.stack = cmp_data.get("NPU_Stack_Info") + if "input" in name: + self.inputs[name] = metrics + elif "output" in name: + self.outputs[name] = metrics + def gen_node_info(self, path: RankPath): - cache = FileCache() - if Const.FORWARD in self.op_name: - data_info_list = {Const.INPUT_ARGS: self.input_args, Const.INPUT_KWARGS: self.input_kwargs, - Const.OUTPUT: self.outputs} - else: - data_info_list = {Const.INPUT: self.inputs, Const.OUTPUT: self.outputs} + data_info_list = {Const.INPUT: self.inputs, Const.OUTPUT: self.outputs} return {'op_name': self.op_name, 'data_info': data_info_list, 'construct_info': "None", - 'stack_info': self.stack} + 'stack_info': self.find_stack()} class CommunicationNode: @@ -133,44 +129,24 @@ class CommunicationNode: """ tar_api = DiffAnalyseConst.P2P_API_MAPPING.get(self.api, self.api) ranks = set() - for dst in [DiffAnalyseConst.DST, DiffAnalyseConst.DST_GROUP]: - if dst in self.data.op_name: - dst_value = self.data.input_kwargs.get(dst) - if dst_value: - ranks.add(dst_value.get('value')) - break - for src in [DiffAnalyseConst.SRC, DiffAnalyseConst.SRC_GROUP]: - if src in self.data.input_kwargs: - src_value = self.data.input_kwargs.get(src) - if src_value: - ranks.add(src_value.get('value')) - break - if not ranks: - for item in self.data.input_args: - if isinstance(item, dict) and item.get(Const.TYPE) == 'int': - ranks.add(item.get('value')) - group = self.data.input_kwargs.get('group') - if group: - ranks.update(group.get('group_ranks')) + # 遍历DST和SRC相关的input,获取对应的rank值 + # 遍历inputs获取所有rank值 + for k, v in self.data.inputs: + if any(t in k for t in [DiffAnalyseConst.DST, DiffAnalyseConst.DST_GROUP, + DiffAnalyseConst.SRC, DiffAnalyseConst.SRC_GROUP]): + if val := v.get("NPU max"): + ranks.add(val) + elif k.endswith('.group'): + ranks.update(list(v.get('NPU max'))) + return {'ranks': ranks, 'api': f'Distributed.{tar_api}', 'type': DiffAnalyseConst.OPPOSITE_DIR.get(self.type, DiffAnalyseConst.LINK)} def _resolve_type(self): - for src in [DiffAnalyseConst.SRC, DiffAnalyseConst.SRC_GROUP]: - if src in self.data.input_kwargs and self.data.input_kwargs[src]: - if self.data.input_kwargs[src].get('value') == self.rank: - return DiffAnalyseConst.SRC - else: - return DiffAnalyseConst.DST - for dst in [DiffAnalyseConst.DST, DiffAnalyseConst.DST_GROUP]: - if dst in self.data.input_kwargs and self.data.input_kwargs[dst]: - if self.data.input_kwargs[dst].get('value') == self.rank: - return DiffAnalyseConst.DST - else: - return DiffAnalyseConst.SRC - if self.api in DiffAnalyseConst.DIRECTED_API: - for item in self.data.input_args: - if item.get(Const.TYPE) == 'int': - node_type = DiffAnalyseConst.DIRECTED_API[self.api] - return node_type if item.get('value') == self.rank else DiffAnalyseConst.OPPOSITE_DIR[node_type] + # 遍历SRC和DST相关的输入,根据rank值判断节点类型 + for prefix, node_type in [(DiffAnalyseConst.SRC, DiffAnalyseConst.SRC), + (DiffAnalyseConst.DST, DiffAnalyseConst.DST)]: + for k, v in self.data.inputs: + if prefix in k or f"{prefix}_GROUP" in k: + return node_type if v.get("NPU max") == self.rank else DiffAnalyseConst.OPPOSITE_DIR[node_type] return DiffAnalyseConst.LINK -- Gitee From 34885d95efef612a531078c4506c90e59cea888c Mon Sep 17 00:00:00 2001 From: TAJh Date: Sat, 21 Jun 2025 15:28:23 +0800 Subject: [PATCH 13/18] version 13.0 --- debug/accuracy_tools/msprobe/find_first/graph.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/debug/accuracy_tools/msprobe/find_first/graph.py b/debug/accuracy_tools/msprobe/find_first/graph.py index d20b143b4..740280ff3 100644 --- a/debug/accuracy_tools/msprobe/find_first/graph.py +++ b/debug/accuracy_tools/msprobe/find_first/graph.py @@ -16,7 +16,7 @@ from dataclasses import dataclass from msprobe.core.common.const import Const from msprobe.core.common.log import logger -from msprobe.find_first.utils import FileCache, RankPath, is_ignore_op, check_item_anomaly, DiffAnalyseConst +from msprobe.find_first.utils import RankPath, DiffAnalyseConst @dataclass -- Gitee From 0776d7bcc58b7bcf42f920732532aef624328cad Mon Sep 17 00:00:00 2001 From: TAJh Date: Sat, 21 Jun 2025 18:29:14 +0800 Subject: [PATCH 14/18] version 14.0 --- .../accuracy_tools/msprobe/find_first/analyzer.py | 10 +++++----- debug/accuracy_tools/msprobe/find_first/graph.py | 15 ++++++++++----- 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/debug/accuracy_tools/msprobe/find_first/analyzer.py b/debug/accuracy_tools/msprobe/find_first/analyzer.py index fa75e3242..13495fe47 100644 --- a/debug/accuracy_tools/msprobe/find_first/analyzer.py +++ b/debug/accuracy_tools/msprobe/find_first/analyzer.py @@ -20,7 +20,7 @@ class DiffAnalyzer: self._diff_nodes = [] # 记录所有异常节点 self._cache = FileCache() self._first_comm_nodes = {} # 记录各rank下首个通信节点的node_id - self._after_comm_anomalies = {} # 记录各rank下发生在通信节点之后的异常计算节点 + self._after_comm_diffs = {} # 记录各rank下发生在通信节点之后的异常计算节点 self._rank_comm_nodes_dict = {} # 记录各rank的通信节点 def analyze(self): @@ -74,7 +74,7 @@ class DiffAnalyzer: def _post_analyze(self): logger.info('Start searching diff node after communication.') - for nodes in self._after_comm_anomalies.values(): + for nodes in self._after_comm_diffs.values(): if nodes: self._diff_nodes.append(nodes[0]) @@ -91,7 +91,7 @@ class DiffAnalyzer: logger.info(f'Cannot find connected communication node for "{cur_node.node_id}".') def _find_connection(self, conn_info, cur_node, searched_ranks, seen_nodes): - def connect(): + def connect(search_node): seen_nodes.add(search_node.node_id) if search_node.type == DiffAnalyseConst.DST: cur_node.add_dst(search_node) @@ -114,7 +114,7 @@ class DiffAnalyzer: search_conn_ranks = search_node.find_connected_nodes().get('ranks') if ((not search_conn_ranks and search_node.api not in DiffAnalyseConst.DIRECTED_API) or cur_node.rank in search_conn_ranks): # 有些无向通信算子没有填ProcessGroup,默认连接所有rank - connect() + connect(search_node) found = True break return found @@ -146,7 +146,7 @@ class DiffAnalyzer: compute_ops.append(data_node) sub_layer += 1 if compute_ops: - self._after_comm_anomalies[rank] = compute_ops + self._after_comm_diffs[rank] = compute_ops return communication_nodes def _pruning(self): diff --git a/debug/accuracy_tools/msprobe/find_first/graph.py b/debug/accuracy_tools/msprobe/find_first/graph.py index 740280ff3..8897483fe 100644 --- a/debug/accuracy_tools/msprobe/find_first/graph.py +++ b/debug/accuracy_tools/msprobe/find_first/graph.py @@ -34,6 +34,8 @@ class DataNode: self.op_name = op_name self.rank = rank self.stack = None + self.inputs = None + self.outputs = None self.parse_data(op_data) self.sub_layer = kwargs.get('sub_layer', 0) self.is_diff = False @@ -48,7 +50,7 @@ class DataNode: self.is_diff = not op_data.get("is_same", True) self.op_data = op_data.get("op_items") # 这里拿到的是比对column,是一个list,有若干行 for cmp_data in self.op_data: - name = cmp_data.get("NPU_Name") + name = cmp_data.get("NPU Name") metrics = {"NPU max": cmp_data.get("NPU max"), "NPU min": cmp_data.get("NPU min"), "NPU mean": cmp_data.get("NPU mean")} @@ -131,13 +133,16 @@ class CommunicationNode: ranks = set() # 遍历DST和SRC相关的input,获取对应的rank值 # 遍历inputs获取所有rank值 - for k, v in self.data.inputs: + for k, v in self.data.inputs.items(): if any(t in k for t in [DiffAnalyseConst.DST, DiffAnalyseConst.DST_GROUP, DiffAnalyseConst.SRC, DiffAnalyseConst.SRC_GROUP]): - if val := v.get("NPU max"): + if val := int(v.get("NPU max")): ranks.add(val) elif k.endswith('.group'): - ranks.update(list(v.get('NPU max'))) + val = v.get('NPU max') + if val and '[]' in val: + val = [int(part) for part in val.strip('[]').split(',')] + ranks.update(val) return {'ranks': ranks, 'api': f'Distributed.{tar_api}', 'type': DiffAnalyseConst.OPPOSITE_DIR.get(self.type, DiffAnalyseConst.LINK)} @@ -146,7 +151,7 @@ class CommunicationNode: # 遍历SRC和DST相关的输入,根据rank值判断节点类型 for prefix, node_type in [(DiffAnalyseConst.SRC, DiffAnalyseConst.SRC), (DiffAnalyseConst.DST, DiffAnalyseConst.DST)]: - for k, v in self.data.inputs: + for k, v in self.data.inputs.items(): if prefix in k or f"{prefix}_GROUP" in k: return node_type if v.get("NPU max") == self.rank else DiffAnalyseConst.OPPOSITE_DIR[node_type] return DiffAnalyseConst.LINK -- Gitee From 4965600495370d38843bffe903ac722cac21c885 Mon Sep 17 00:00:00 2001 From: TAJh Date: Sat, 21 Jun 2025 18:31:45 +0800 Subject: [PATCH 15/18] bfx --- .../mindspore/monitor/optimizer_collect.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/debug/accuracy_tools/msprobe/mindspore/monitor/optimizer_collect.py b/debug/accuracy_tools/msprobe/mindspore/monitor/optimizer_collect.py index 90bb0dd39..77562eab5 100644 --- a/debug/accuracy_tools/msprobe/mindspore/monitor/optimizer_collect.py +++ b/debug/accuracy_tools/msprobe/mindspore/monitor/optimizer_collect.py @@ -1,3 +1,17 @@ +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from abc import abstractmethod from mindspore import mint, ops @@ -76,6 +90,9 @@ class OptimizerMon(object): state_param = self.state.get(hp_param, {}) exp_avg = self.narrow_from_flatten(lp_param, state_param.get("exp_avg", None)) exp_avg_sq = self.narrow_from_flatten(lp_param, state_param.get("exp_avg_sq", None)) + if exp_avg is None or exp_avg_sq is None: + logger.warning(f"exp_avg or exp_avg_sq of {name} is None, skip calculation.") + continue if monitor.mv_distribution: exp_avg_dict[name] = exp_avg exp_avg_sq_dict[name] = exp_avg_sq -- Gitee From 30735c42b75892f1f5ba79266d62120c94fa5e9c Mon Sep 17 00:00:00 2001 From: TAJh Date: Sat, 21 Jun 2025 18:32:30 +0800 Subject: [PATCH 16/18] bfx --- .../msprobe/mindspore/monitor/optimizer_collect.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/debug/accuracy_tools/msprobe/mindspore/monitor/optimizer_collect.py b/debug/accuracy_tools/msprobe/mindspore/monitor/optimizer_collect.py index 77562eab5..7efbb4590 100644 --- a/debug/accuracy_tools/msprobe/mindspore/monitor/optimizer_collect.py +++ b/debug/accuracy_tools/msprobe/mindspore/monitor/optimizer_collect.py @@ -90,9 +90,6 @@ class OptimizerMon(object): state_param = self.state.get(hp_param, {}) exp_avg = self.narrow_from_flatten(lp_param, state_param.get("exp_avg", None)) exp_avg_sq = self.narrow_from_flatten(lp_param, state_param.get("exp_avg_sq", None)) - if exp_avg is None or exp_avg_sq is None: - logger.warning(f"exp_avg or exp_avg_sq of {name} is None, skip calculation.") - continue if monitor.mv_distribution: exp_avg_dict[name] = exp_avg exp_avg_sq_dict[name] = exp_avg_sq @@ -108,6 +105,9 @@ class OptimizerMon(object): else: logger.warning(f"step of {name} is None, maybe something wrong happened.") continue + if exp_avg is None or exp_avg_sq is None: + logger.warning(f"exp_avg or exp_avg_sq of {name} is None, skip calculation.") + continue exp_avg_hat = exp_avg / (1 - self.optim.defaults['betas'][0] ** step) exp_avg_sq_hat = exp_avg_sq / (1 - self.optim.defaults['betas'][1] ** step) update_dict[name] = exp_avg_hat / (mint.sqrt(exp_avg_sq_hat) + self.optim.defaults['eps']) -- Gitee From c787dafb52a00b98faf0208cee90ac956425f85d Mon Sep 17 00:00:00 2001 From: TAJh Date: Mon, 23 Jun 2025 10:41:13 +0800 Subject: [PATCH 17/18] version 16.0 --- debug/accuracy_tools/msprobe/find_first/graph.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/debug/accuracy_tools/msprobe/find_first/graph.py b/debug/accuracy_tools/msprobe/find_first/graph.py index 8897483fe..3bde748a0 100644 --- a/debug/accuracy_tools/msprobe/find_first/graph.py +++ b/debug/accuracy_tools/msprobe/find_first/graph.py @@ -34,14 +34,14 @@ class DataNode: self.op_name = op_name self.rank = rank self.stack = None - self.inputs = None - self.outputs = None + self.inputs = {} + self.outputs = {} + self.is_diff = False self.parse_data(op_data) self.sub_layer = kwargs.get('sub_layer', 0) - self.is_diff = False def find_stack(self): - for item in self.stack.values(): + for item in self.stack: if len(item) >= 2 and self.op_name in item[0]: return item[1] return {} @@ -66,7 +66,7 @@ class DataNode: return {'op_name': self.op_name, 'data_info': data_info_list, 'construct_info': "None", - 'stack_info': self.find_stack()} + 'stack_info': self.stack} class CommunicationNode: -- Gitee From a4722063c6fb98f0eaf65820c4652fbabff80126 Mon Sep 17 00:00:00 2001 From: TAJh Date: Mon, 23 Jun 2025 15:17:54 +0800 Subject: [PATCH 18/18] 123 --- .../msprobe/find_first/analyzer.py | 15 ++++++ .../msprobe/find_first/graph.py | 50 +++++++++++-------- .../msprobe/find_first/utils.py | 17 +++++++ 3 files changed, 62 insertions(+), 20 deletions(-) diff --git a/debug/accuracy_tools/msprobe/find_first/analyzer.py b/debug/accuracy_tools/msprobe/find_first/analyzer.py index 13495fe47..888d97b36 100644 --- a/debug/accuracy_tools/msprobe/find_first/analyzer.py +++ b/debug/accuracy_tools/msprobe/find_first/analyzer.py @@ -1,3 +1,18 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import time from collections import defaultdict import os diff --git a/debug/accuracy_tools/msprobe/find_first/graph.py b/debug/accuracy_tools/msprobe/find_first/graph.py index 3bde748a0..cb6cb7f69 100644 --- a/debug/accuracy_tools/msprobe/find_first/graph.py +++ b/debug/accuracy_tools/msprobe/find_first/graph.py @@ -16,6 +16,7 @@ from dataclasses import dataclass from msprobe.core.common.const import Const from msprobe.core.common.log import logger +from msprobe.core.common.const import CompareConst from msprobe.find_first.utils import RankPath, DiffAnalyseConst @@ -50,12 +51,13 @@ class DataNode: self.is_diff = not op_data.get("is_same", True) self.op_data = op_data.get("op_items") # 这里拿到的是比对column,是一个list,有若干行 for cmp_data in self.op_data: - name = cmp_data.get("NPU Name") - metrics = {"NPU max": cmp_data.get("NPU max"), - "NPU min": cmp_data.get("NPU min"), - "NPU mean": cmp_data.get("NPU mean")} - if cmp_data.get("NPU_Stack_Info") != "N/A" and not self.stack: - self.stack = cmp_data.get("NPU_Stack_Info") + name = cmp_data.get(CompareConst.NPU_NAME) + metrics = {CompareConst.NPU_MAX: cmp_data.get(CompareConst.NPU_MAX), + CompareConst.NPU_MIN: cmp_data.get(CompareConst.NPU_MIN), + CompareConst.NPU_MEAN: cmp_data.get(CompareConst.NPU_MEAN), + CompareConst.NPU_NORM: cmp_data.get(CompareConst.NPU_NORM)} + if cmp_data.get(CompareConst.STACK) != "N/A" and not self.stack: + self.stack = cmp_data.get(CompareConst.STACK) if "input" in name: self.inputs[name] = metrics elif "output" in name: @@ -65,7 +67,6 @@ class DataNode: data_info_list = {Const.INPUT: self.inputs, Const.OUTPUT: self.outputs} return {'op_name': self.op_name, 'data_info': data_info_list, - 'construct_info': "None", 'stack_info': self.stack} @@ -117,13 +118,17 @@ class CommunicationNode: for node in self.next_nodes.values(): node.pre_node = None for node in self.dst_nodes.values(): - node.src_nodes.pop(self.node_id) + if node.src_nodes: + node.src_nodes.pop(self.node_id) for node in self.src_nodes.values(): - node.dst_nodes.pop(self.node_id) + if node.dst_nodes: + node.dst_nodes.pop(self.node_id) for node in self.link_nodes.values(): - node.link_nodes.pop(self.node_id) + if node.link_nodes: + node.link_nodes.pop(self.node_id) if self.pre_node: - self.pre_node.next_nodes.pop(self.node_id) + if self.pre_node.next_nodes: + self.pre_node.next_nodes.pop(self.node_id) def find_connected_nodes(self): """ @@ -133,14 +138,18 @@ class CommunicationNode: ranks = set() # 遍历DST和SRC相关的input,获取对应的rank值 # 遍历inputs获取所有rank值 - for k, v in self.data.inputs.items(): - if any(t in k for t in [DiffAnalyseConst.DST, DiffAnalyseConst.DST_GROUP, - DiffAnalyseConst.SRC, DiffAnalyseConst.SRC_GROUP]): - if val := int(v.get("NPU max")): - ranks.add(val) - elif k.endswith('.group'): - val = v.get('NPU max') - if val and '[]' in val: + for input_name, v in self.data.inputs.items(): + # 检查key是否包含DST/SRC相关标识 + target_types = [DiffAnalyseConst.DST, DiffAnalyseConst.DST_GROUP, + DiffAnalyseConst.SRC, DiffAnalyseConst.SRC_GROUP] + if any(keyword in input_name for keyword in target_types): + # 获取NPU_MAX值并转为整数 + rank_val = int(v.get(CompareConst.NPU_MAX, 0)) + if rank_val: + ranks.add(rank_val) + elif input_name.endswith('.group'): + val = v.get(CompareConst.NPU_MAX) + if val and val.startswith('[') and val.endswith(']'): val = [int(part) for part in val.strip('[]').split(',')] ranks.update(val) @@ -153,5 +162,6 @@ class CommunicationNode: (DiffAnalyseConst.DST, DiffAnalyseConst.DST)]: for k, v in self.data.inputs.items(): if prefix in k or f"{prefix}_GROUP" in k: - return node_type if v.get("NPU max") == self.rank else DiffAnalyseConst.OPPOSITE_DIR[node_type] + return node_type if v.get(CompareConst.NPU_MAX) == self.rank \ + else DiffAnalyseConst.OPPOSITE_DIR[node_type] return DiffAnalyseConst.LINK diff --git a/debug/accuracy_tools/msprobe/find_first/utils.py b/debug/accuracy_tools/msprobe/find_first/utils.py index 0e8c30760..58729c118 100644 --- a/debug/accuracy_tools/msprobe/find_first/utils.py +++ b/debug/accuracy_tools/msprobe/find_first/utils.py @@ -1,3 +1,18 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from collections import OrderedDict from dataclasses import dataclass import sys @@ -6,6 +21,7 @@ import psutil from msprobe.core.common.file_utils import check_file_or_directory_path, load_json + @dataclass class RankPath: rank: int @@ -18,6 +34,7 @@ class RankPath: check_file_or_directory_path(dump_path) self.dump_path = dump_path + class FileCache: """ lazy load file -- Gitee