diff --git a/debug/accuracy_tools/msprobe/nan_analyze/analyzer.py b/debug/accuracy_tools/msprobe/nan_analyze/analyzer.py index e147f23b7c7bd514a13251830e0365928876bc75..a32ad7c351cfcaf33eb2b3e6fb78d30994b1015a 100644 --- a/debug/accuracy_tools/msprobe/nan_analyze/analyzer.py +++ b/debug/accuracy_tools/msprobe/nan_analyze/analyzer.py @@ -221,7 +221,7 @@ class NanAnalyzer: node = get_next_node(nodes) if not node: continue - if not groups or node.node_id in all_ids_in_groups: + if not groups or node.node_id not in all_ids_in_groups: new_group = find_all_members(node) groups.append(new_group) all_ids_in_groups.update(new_group) diff --git a/debug/accuracy_tools/msprobe/nan_analyze/graph.py b/debug/accuracy_tools/msprobe/nan_analyze/graph.py index 5a4f8fb87296a39796b5124854ba7060be71d53a..979e9e296aef998959eb8c1be6041c925e80683e 100644 --- a/debug/accuracy_tools/msprobe/nan_analyze/graph.py +++ b/debug/accuracy_tools/msprobe/nan_analyze/graph.py @@ -99,13 +99,13 @@ class CommunicationNode: self.link_nodes = kwargs.get('link_nodes', {}) self.dst_nodes = kwargs.get('dst_nodes', {}) self.src_nodes = kwargs.get('src_nodes', {}) - self.next_nodes = kwargs.get('next_nodes', {}) + self.next_node = kwargs.get('next_node') self.compute_ops = kwargs.get('compute_ops', []) self.type = self._resolve_type() self.connected = False def add_next(self, node): - self.next_nodes[node.node_id] = node + self.next_node = node node.pre_node = self node.layer = self.layer + 1 node.data.layer = node.layer @@ -113,7 +113,9 @@ class CommunicationNode: def add_link(self, node): self.link_nodes[node.node_id] = node node.link_nodes[self.node_id] = self - node.layer = self.layer + layer = max(node.layer, self.layer) + self.update_layer(layer) + node.update_layer(layer) node.data.layer = node.layer self.connected = True node.connected = True @@ -121,14 +123,16 @@ class CommunicationNode: def add_dst(self, node): self.dst_nodes[node.node_id] = node node.src_nodes[self.node_id] = self - node.layer = self.layer + layer = max(node.layer, self.layer) + self.update_layer(layer) + node.update_layer(layer) node.data.layer = node.layer self.connected = True node.connected = True def delete(self): - for node in self.next_nodes.values(): - node.pre_node = None + if self.next_node: + self.next_node.pre_node = None for node in self.dst_nodes.values(): node.src_nodes.pop(self.node_id) for node in self.src_nodes.values(): @@ -136,11 +140,29 @@ class CommunicationNode: for node in self.link_nodes.values(): node.link_nodes.pop(self.node_id) if self.pre_node: - self.pre_node.next_nodes.pop(self.node_id) + self.pre_node.next_node = None + + def update_layer(self, layer): + if layer == self.layer: + return + + def update_comm_layer(node): + nodes = set(node.src_nodes.values()) | set(node.dst_nodes.values()) | set(node.link_nodes.values()) + for comm_node in nodes: + comm_node.update_layer(layer) + + self.layer = layer + update_comm_layer(self) + next_node = self.next_node + while next_node: + layer += 1 + next_node.layer = layer + update_comm_layer(next_node) + next_node = next_node.next_node def has_nan_inf(self): return self.input_has_nan_inf() or check_item_anomaly(self.data.outputs) - + def input_has_nan_inf(self): return check_item_anomaly(self.data.input_args) or check_item_anomaly(self.data.input_kwargs) diff --git a/debug/accuracy_tools/msprobe/test/nan_analyze_ut/test_nan_analyzer_graph.py b/debug/accuracy_tools/msprobe/test/nan_analyze_ut/test_nan_analyzer_graph.py index 3c1c43b3b109d4d232622cb581356e29f4eb83b8..50dc6cba83802e5fc45f6fca6c2815132b361f6d 100644 --- a/debug/accuracy_tools/msprobe/test/nan_analyze_ut/test_nan_analyzer_graph.py +++ b/debug/accuracy_tools/msprobe/test/nan_analyze_ut/test_nan_analyzer_graph.py @@ -37,7 +37,7 @@ class TestCommunicationNode(unittest.TestCase): comm_node_0.add_next(comm_node_1) self.assertEqual(comm_node_0.layer + 1, comm_node_1.layer) self.assertTrue(comm_node_0 is comm_node_1.pre_node) - self.assertTrue(comm_node_1.node_id in comm_node_0.next_nodes) + self.assertTrue(comm_node_0.next_node is comm_node_1) def test_add_link(self): op_name = 'Distributed.all_gather.0.forward' @@ -67,7 +67,7 @@ class TestCommunicationNode(unittest.TestCase): comm_node_0.add_dst(comm_node_1) comm_node_0.delete() self.assertFalse(comm_node_1.src_nodes) - self.assertFalse(comm_node_2.next_nodes) + self.assertFalse(comm_node_2.next_node) def test_has_nan_inf(self): op_name = 'Distributed.broadcast.0.forward' diff --git a/debug/accuracy_tools/msprobe/visualization/graph_service.py b/debug/accuracy_tools/msprobe/visualization/graph_service.py index b14ccab0386be92c0cdce7ebc89854a9ce17aa92..a9f7870beabd04658b395e3349dd0be416af27f5 100644 --- a/debug/accuracy_tools/msprobe/visualization/graph_service.py +++ b/debug/accuracy_tools/msprobe/visualization/graph_service.py @@ -66,7 +66,7 @@ def _compare_graph_result(input_param, args): # 对两个数据进行构图 graph_n = _build_graph_info(input_param.get('npu_path'), args) graph_b = _build_graph_info(input_param.get('bench_path'), args) - logger.info('Model graphs built successfully, start Comparing graphs...') + logger.info('Model graphs built successfully, start comparing graphs...') # 基于graph、stack和data进行比较 graph_comparator = _compare_graph(graph_n, graph_b, input_param, args) # 增加micro step标记