From e440b7659a4caa06bcd72867da18103fd3537679 Mon Sep 17 00:00:00 2001 From: l30044004 Date: Thu, 24 Jul 2025 16:52:21 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90Feature=E3=80=91=E5=88=86=E7=BA=A7?= =?UTF-8?q?=E5=8F=AF=E8=A7=86=E5=8C=96=E6=94=AF=E6=8C=81tp=20pp=E5=88=87?= =?UTF-8?q?=E5=88=86=E5=90=88=E5=B9=B6-part2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../msprobe/docs/21.visualization_PyTorch.md | 43 +- .../docs/22.visualization_MindSpore.md | 42 +- .../builder/test_graph_merger.py | 409 ++++++++++++++++++ .../builder/test_msprobe_adapter.py | 9 +- .../visualization_ut/test_graph_service.py | 2 + .../visualization/builder/msprobe_adapter.py | 20 +- .../visualization/compare/graph_comparator.py | 42 +- .../msprobe/visualization/graph/base_node.py | 4 + .../msprobe/visualization/graph/graph.py | 6 +- .../msprobe/visualization/graph_service.py | 105 ++++- 10 files changed, 635 insertions(+), 47 deletions(-) create mode 100644 debug/accuracy_tools/msprobe/test/visualization_ut/builder/test_graph_merger.py diff --git a/debug/accuracy_tools/msprobe/docs/21.visualization_PyTorch.md b/debug/accuracy_tools/msprobe/docs/21.visualization_PyTorch.md index f5e81f2d8c..7bc34a88d6 100644 --- a/debug/accuracy_tools/msprobe/docs/21.visualization_PyTorch.md +++ b/debug/accuracy_tools/msprobe/docs/21.visualization_PyTorch.md @@ -90,7 +90,7 @@ msprobe -f pytorch graph -i ./compare.json -o ./output | npu_path | 指定待调试侧比对路径,str类型。工具根据路径格式自动进行单rank比对、多rank批量比对或多step批量比对,具体格式参考3.2 图构建和比对。 | 是 | | bench_path | 指定标杆侧比对路径,str类型。单图构建场景可以不配置。 | 否 | | is_print_compare_log | 配置是否开启单个算子的日志打屏。可取值 true 或 false,默认为 true。关闭后则只输出常规日志,bool 类型。 | 否 | - +| parallel_merge | 配置是否开启不同切分策略下的图合并,dict类型。rank_size、tp、pp参数按实际情况进行配置。比对时配置npu、bench,只构图配置npu。 配置示例见[3.2.5 不同切分策略下的图合并](#325-不同切分策略下的图合并)。 | 否 | ### 3.2 图构建和比对 @@ -315,6 +315,47 @@ dump配置请参考[dump配置示例](./03.config_examples.md#16-task-配置为- 得到dump数据后,若需比较特定两个rank之间的数据,请参考[3.2.2 双图比对](#322-双图比对);若需进行多个rank或多个step的数据批量比对,请参考[3.2.3 批量构建或比对](#323-批量构建或比对)。 +#### 3.2.5 不同切分策略下的图合并 + +适用场景:不同模型并行切分策略下,两个模型产生了精度差异,需要进行整网数据比对,但被切分的数据或模型结构分布于多rank中无法进行比对,需要将分布在各个rank的数据或模型结构合并后再进行比对。 + +使用限制: + +- 当前支持的模型并行切分策略:Tensor Parallelism(TP)、Pipeline Parallelism(PP)。 +- 当前支持基于Megatron、MindSpeed-LLM套件的模型进行图合并,其他套件的模型图合并效果有待验证; +- 当前仅支持msprobe工具dump的statistics数据; +- 图合并比对时要确保Data Parallelism(DP)切分一致,例如rank=8 tp=1 pp=8的配置,dp=1,图合并将得到一张图,rank=8 tp=1 pp=4的配置,dp=2,图合并将得到两张图,暂不支持数量不一致的图进行比对。 + +使能方式: + +在compare.json里增加parallel_merge配置项, rank_size、tp、pp和vpp参数按实际情况进行配置。 + +参数说明: + +所需tp、pp和vpp参数来自于Megatron、MindSpeed-LLM套件中的训练脚本实际配置。 + +| 参数名 | 说明 | 是否必填 | +|-----------|--------------------------------------------------------------------------------------------------------------------------|------| +| rank_size | 模型实际训练所用加速卡的数量,int类型。`rank_size=tp*pp*cp*dp`,由于暂不支持CP合并,图合并功能中默认cp=1。 | 是 | +| tp | 张量并行大小,int类型。实际训练脚本中需指定`--tensor-model-parallel-size T`,其中`T`表示张量模型并行大小,即**图合并所需的参数tp**, `tp=T`。 | 是 | +| pp | 流水线并行的阶段数,int类型。实际训练脚本中需指定`--pipeline-model-parallel-size P`,其中`P`表示流水线并行的阶段数,即**图合并所需的参数pp**, `pp=P`。 | 是 | + +npu_path、bench_path的配置以及执行命令请参考[3.2.3 批量构建或比对](#323-批量构建或比对) + +如果只进行图构建,"bench_path"和"parallel_merge"中的"bench"参数可不配置。 + +``` +{ + "npu_path": "./npu_dump", + "bench_path": "./bench_dump", + "is_print_compare_log": true, + "parallel_merge": { + "npu": {"rank_size": 8, "tp": 8, "pp": 1}, + "bench": {"rank_size": 8, "tp": 1, "pp": 8} + } +} +``` + ## 4.启动tensorboard ### 4.1 可直连的服务器 diff --git a/debug/accuracy_tools/msprobe/docs/22.visualization_MindSpore.md b/debug/accuracy_tools/msprobe/docs/22.visualization_MindSpore.md index c621ede001..6780774b9c 100644 --- a/debug/accuracy_tools/msprobe/docs/22.visualization_MindSpore.md +++ b/debug/accuracy_tools/msprobe/docs/22.visualization_MindSpore.md @@ -91,7 +91,7 @@ msprobe -f mindspore graph -i ./compare.json -o ./output | npu_path | 指定待调试侧比对路径,str类型。工具根据路径格式自动进行单rank比对、多rank批量比对或多step批量比对,具体格式参考3.2 图构建和比对。 | 是 | | bench_path | 指定标杆侧比对路径,str类型。单图构建场景可以不配置。 | 否 | | is_print_compare_log | 配置是否开启单个算子的日志打屏。可取值 true 或 false,默认为 true。关闭后则只输出常规日志,bool 类型。 | 否 | - +| parallel_merge | 配置是否开启不同切分策略下的图合并,dict类型。rank_size、tp、pp参数按实际情况进行配置。比对时配置npu、bench,只构图配置npu。 配置示例见[3.2.5 不同切分策略下的图合并](#325-不同切分策略下的图合并)。 | 否 | ### 3.2 图构建和比对 @@ -316,6 +316,46 @@ dump配置请参考[dump配置示例](./03.config_examples.md#35-task-配置为- 得到dump数据后,若需比较特定两个rank之间的数据,请参考[3.2.2 双图比对](#322-双图比对);若需进行多个rank或多个step的数据批量比对,请参考[3.2.3 批量构建或比对](#323-批量构建或比对)。 +#### 3.2.5 不同切分策略下的图合并 + +适用场景:不同模型并行切分策略下,两个模型产生了精度差异,需要进行整网数据比对,但被切分的数据或模型结构分布于多rank中无法进行比对,需要将分布在各个rank的数据或模型结构合并后再进行比对。 + +使用限制: + +- 当前支持的模型并行切分策略:Tensor Parallelism(TP)、Pipeline Parallelism(PP)。 +- 当前支持基于Megatron、MindSpeed-LLM套件的模型进行图合并,其他套件的模型图合并效果有待验证; +- 当前仅支持msprobe工具dump的statistics数据; +- 图合并比对时要确保Data Parallelism(DP)切分一致,例如rank=8 tp=1 pp=8的配置,dp=1,图合并将得到一张图,rank=8 tp=1 pp=4的配置,dp=2,图合并将得到两张图,暂不支持数量不一致的图进行比对。 + +使能方式: + +在compare.json里增加parallel_merge配置项, rank_size、tp、pp和vpp参数按实际情况进行配置。 + +参数说明: + +所需tp、pp和vpp参数来自于Megatron、MindSpeed-LLM套件中的训练脚本实际配置。 + +| 参数名 | 说明 | 是否必填 | +|-----------|--------------------------------------------------------------------------------------------------------------------------|------| +| rank_size | 模型实际训练所用加速卡的数量,int类型。`rank_size=tp*pp*cp*dp`,由于暂不支持CP合并,图合并功能中默认cp=1。 | 是 | +| tp | 张量并行大小,int类型。实际训练脚本中需指定`--tensor-model-parallel-size T`,其中`T`表示张量模型并行大小,即**图合并所需的参数tp**, `tp=T`。 | 是 | +| pp | 流水线并行的阶段数,int类型。实际训练脚本中需指定`--pipeline-model-parallel-size P`,其中`P`表示流水线并行的阶段数,即**图合并所需的参数pp**, `pp=P`。 | 是 | + +npu_path、bench_path的配置以及执行命令请参考[3.2.3 批量构建或比对](#323-批量构建或比对) + +如果只进行图构建,"bench_path"和"parallel_merge"中的"bench"参数可不配置。 + +``` +{ + "npu_path": "./npu_dump", + "bench_path": "./bench_dump", + "is_print_compare_log": true, + "parallel_merge": { + "npu": {"rank_size": 8, "tp": 8, "pp": 1}, + "bench": {"rank_size": 8, "tp": 1, "pp": 8} + } +} +``` ## 4.启动TensorBoard diff --git a/debug/accuracy_tools/msprobe/test/visualization_ut/builder/test_graph_merger.py b/debug/accuracy_tools/msprobe/test/visualization_ut/builder/test_graph_merger.py new file mode 100644 index 0000000000..6ad0de9fcd --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/visualization_ut/builder/test_graph_merger.py @@ -0,0 +1,409 @@ +import unittest +from unittest.mock import patch, MagicMock, call +from msprobe.visualization.builder.graph_merger import ( + GraphMerger, BaseGraphMerger, PPMerger, TPMerger, + NoParallelMerger, TPPPMerger, FullMerger +) +from msprobe.core.common.const import Const +from msprobe.visualization.utils import GraphConst +from msprobe.visualization.graph.node_op import NodeOp +from msprobe.visualization.graph.graph import Graph +from msprobe.core.common.exceptions import MsprobeException + + +class TestGraphMerger(unittest.TestCase): + def setUp(self): + self.build_graph_results = MagicMock() + self.parallel_param = MagicMock(tp=1, pp=1, rank_size=1) + self.is_bench = False + + def test_select_strategy_no_parallel(self): + self.parallel_param.tp = self.parallel_param.pp = self.parallel_param.rank_size = 1 + merger = GraphMerger(self.build_graph_results, self.parallel_param, self.is_bench) + self.assertIsInstance(merger.strategy, NoParallelMerger) + + def test_select_strategy_tp(self): + self.parallel_param.tp = self.parallel_param.rank_size = 2 + self.parallel_param.pp = 1 + merger = GraphMerger(self.build_graph_results, self.parallel_param, self.is_bench) + self.assertIsInstance(merger.strategy, TPMerger) + + def test_select_strategy_pp(self): + self.parallel_param.pp = self.parallel_param.rank_size = 2 + self.parallel_param.tp = 1 + merger = GraphMerger(self.build_graph_results, self.parallel_param, self.is_bench) + self.assertIsInstance(merger.strategy, PPMerger) + + def test_select_strategy_tp_pp(self): + self.parallel_param.tp = self.parallel_param.pp = 2 + self.parallel_param.rank_size = 4 + merger = GraphMerger(self.build_graph_results, self.parallel_param, self.is_bench) + self.assertIsInstance(merger.strategy, TPPPMerger) + + def test_select_strategy_full(self): + self.parallel_param.tp = 2 + self.parallel_param.pp = 2 + self.parallel_param.rank_size = 8 + merger = GraphMerger(self.build_graph_results, self.parallel_param, self.is_bench) + self.assertIsInstance(merger.strategy, FullMerger) + + def test_merge_graph(self): + merger = GraphMerger(self.build_graph_results, self.parallel_param, self.is_bench) + merger.strategy.merge_graphs = MagicMock() + merger.merge_graph() + merger.strategy.merge_graphs.assert_called_once() + + +class TestBaseGraphMerger(unittest.TestCase): + def setUp(self): + self.build_graph_results = [MagicMock(rank=i) for i in range(2)] + self.parallel_param = MagicMock(tp=1, pp=1, rank_size=2) + self.is_bench = False + self.merger = BaseGraphMerger(self.build_graph_results, self.parallel_param, self.is_bench) + + def test_sort_merged_api_collection(self): + graph = MagicMock() + root = MagicMock() + graph.root = root + subnode1 = MagicMock(id=f"{GraphConst.APIS_BETWEEN_MODULES_ALL_RANKS}.0", op=NodeOp.api_collection) + subnode1.subnodes = [MagicMock(id="op_Rank1.0"), MagicMock(id="op_Rank0.0")] + root.subnodes = [subnode1] + self.merger.sort_merged_api_collection(graph) + self.assertEqual([n.id for n in subnode1.subnodes], ["op_Rank0.0", "op_Rank1.0"]) + + def test_update_node_data_key(self): + data_dict = { + "old_id.input.0": {"full_op_name": "old_id.op"}, + "other_key": {"value": "test"} + } + new_dict = self.merger._update_node_data_key("old_id", "new_id", data_dict) + self.assertEqual(new_dict, { + "new_id.input.0": {"full_op_name": "new_id.op"}, + "other_key": {"value": "test"} + }) + + def test_compare_value_same(self): + self.assertTrue(self.merger._compare_value_same(1, 1)) + self.assertFalse(self.merger._compare_value_same(1, 2)) + self.assertTrue(self.merger._compare_value_same("a", "a")) + self.assertTrue(self.merger._compare_value_same(1, 1.00000001, has_uncertainty=True)) + self.assertFalse(self.merger._compare_value_same(1, 1.1, has_uncertainty=True)) + + def test_merge_graph_api_collection(self): + results = [MagicMock() for _ in range(2)] + graph0, graph1 = Graph("name1"), Graph("name2") + results[0].graph, results[1].graph = graph0, graph1 + root0, root1 = MagicMock(), MagicMock() + graph0.root, graph1.root = root0, root1 + node0 = MagicMock(id=f"{GraphConst.APIS_BETWEEN_MODULES}.0") + node0_sub1 = MagicMock(id="sub_op.0") + node0.subnodes = [node0_sub1] + node1 = MagicMock(id=f"{GraphConst.APIS_BETWEEN_MODULES}.0") + node1_sub1 = MagicMock(id="sub_op.0") + graph0.node_map = {f"{GraphConst.APIS_BETWEEN_MODULES}.0": node0} + node1.subnodes = [node1_sub1] + root0.subnodes = [node0] + root1.subnodes = [node1] + + self.merger.merge_graph_api_collection(results) + + self.assertEqual(len(root0.subnodes), 1) + self.assertTrue(root0.subnodes[0].id.startswith(GraphConst.APIS_BETWEEN_MODULES_ALL_RANKS)) + self.assertEqual(len(root0.subnodes[0].subnodes), 1) + + def test_split_graph_results_by_groups(self): + groups = [[0, 1], [2, 3]] + results = [MagicMock(rank=i) for i in range(4)] + self.merger.build_graph_results = results + split = self.merger.split_graph_results_by_groups(groups) + self.assertEqual(len(split), 2) + self.assertEqual([r.rank for r in split[0]], [0, 1]) + self.assertEqual([r.rank for r in split[1]], [2, 3]) + + def test_compare_node_param_data(self): + main_node = MagicMock() + other_nodes = [MagicMock()] + main_node.id = "id" + other_nodes[0].id = "id" + main_node.input_data = {"input.0": {Const.DTYPE: "torch.float16", Const.MAX: 1}} + other_nodes[0].input_data = {"input.0": {Const.DTYPE: "torch.float16", Const.MAX: 2}} + in_diff, out_diff = self.merger.compare_node_param_data(main_node, other_nodes) + self.assertEqual(list(in_diff.keys()), ["input.0"]) + + def test_compare_param_same(self): + param1 = {Const.MAX: 1, Const.MIN: 0, Const.MEAN: 0.5, Const.NORM: 1} + param2 = {Const.MAX: 1, Const.MIN: 0, Const.MEAN: 0.5, Const.NORM: 1} + self.assertTrue(self.merger.compare_param_same(param1, param2)) + + param2[Const.MAX] = 2 + self.assertFalse(self.merger.compare_param_same(param1, param2)) + + def test_add_all_nodes_rank(self): + graph0, graph1 = MagicMock(), MagicMock() + node0, node1 = MagicMock(), MagicMock() + graph0.node_map.values.return_value = [node0] + graph1.node_map.values.return_value = [node1] + self.build_graph_results[0].graph = graph0 + self.build_graph_results[1].graph = graph1 + + self.merger._add_all_nodes_rank() + + self.assertEqual(node0.rank, 0) + self.assertEqual(node1.rank, 1) + + def test_get_default_groups(self): + self.parallel_param.tp = 4 + self.parallel_param.pp = 2 + self.parallel_param.rank_size = 8 + merger = BaseGraphMerger(self.build_graph_results, self.parallel_param, self.is_bench) + tp_groups, pp_groups = merger.get_default_groups() + self.assertEqual(tp_groups, [[0, 1, 2, 3], [4, 5, 6, 7]]) + self.assertEqual(pp_groups, [[0, 4], [1, 5], [2, 6], [3, 7]]) + + self.parallel_param.tp = 2 + self.parallel_param.pp = 2 + self.parallel_param.rank_size = 8 + merger = BaseGraphMerger(self.build_graph_results, self.parallel_param, self.is_bench) + tp_groups, pp_groups = merger.get_default_groups() + self.assertEqual(tp_groups, [[0, 1], [2, 3], [4, 5], [6, 7]]) + self.assertEqual(pp_groups, [[0, 2], [1, 3], [4, 6], [5, 7]]) + + self.parallel_param.tp = 2 + self.parallel_param.pp = 3 + self.parallel_param.rank_size = 8 + merger = BaseGraphMerger(self.build_graph_results, self.parallel_param, self.is_bench) + with self.assertRaises(MsprobeException): + merger.get_default_groups() + + +class TestPPMerger(unittest.TestCase): + def setUp(self): + self.build_graph_results = [MagicMock(rank=i) for i in range(4)] + self.parallel_param = MagicMock(tp=1, pp=4, rank_size=4) + self.is_bench = False + self.merger = PPMerger(self.build_graph_results, self.parallel_param, self.is_bench) + + def test_trace_p2p_mapping(self): + p2p_mapping = {0: 2, 1: 3, 2: 4, 3: 5, 4: 6, 5: 7, 6: 4, 7: 5} + chains = self.merger._trace_p2p_mapping(p2p_mapping) + self.assertEqual(len(chains), 2) + self.assertIn([0, 2, 4, 6], chains) + self.assertIn([1, 3, 5, 7], chains) + + @patch('msprobe.visualization.builder.graph_merger.PPMerger._merge_nodes') + def test_merge_nodes(self, mock_merge): + main_graph = MagicMock() + main_node = MagicMock(id="module.layers.0.forward") + other_graphs = [MagicMock() for _ in range(3)] + for i, g in enumerate(other_graphs): + g.get_node.return_value = MagicMock(id=f"module.layers.{i}.forward") + + self.merger._merge_nodes(main_graph, main_node, other_graphs) + mock_merge.assert_called() + + def test_merge_graphs(self): + self.merger.get_groups = MagicMock(return_value=[[0, 1, 2, 3]]) + self.merger.merge_pp_graphs = MagicMock(return_value=self.build_graph_results[:1]) + results = self.merger.merge_graphs() + self.assertEqual(len(results), 1) + + def test_get_groups(self): + for i, result in enumerate(self.build_graph_results): + graph = MagicMock() + node = MagicMock(id=f"Distributed.send.{i}.forward") + node.input_data = {f"Distributed.send.{i}.forward.input.dst": {"value": (i + 1) % 4}} + graph.node_map.values.return_value = [node] + result.graph = graph + + groups = self.merger.get_groups() + self.assertEqual(len(groups), 1) + self.assertEqual(groups[0], [0, 1, 2, 3]) + + def test_merge_other_unique_nodes(self): + main_graph = MagicMock() + main_node = MagicMock() + other_nodes = [MagicMock()] + main_node.subnodes = [MagicMock(id="main_sub.0")] + other_nodes[0].subnodes = [MagicMock(id="other_sub.0")] + + self.merger._merge_other_unique_nodes(main_graph, main_node, other_nodes) + self.assertEqual(len(main_node.subnodes), 2) + + def test_sort_nodes(self): + graph = MagicMock() + start_node = MagicMock(id="module.layers.0.forward%0%0") + start_node.op = NodeOp.module + api_node = MagicMock(id="Torch.mul.forward.0%0%0") + graph.node_map = {"module.layers.0.forward%0%0": start_node, "Torch.mul.forward.0%0%0": api_node} + parent_node = MagicMock() + parent_node.subnodes = [start_node, api_node] + start_node.upnode = parent_node + + self.merger._sort_nodes(graph, start_node) + self.assertEqual(parent_node.subnodes[0].id, "module.layers.0.forward") + self.assertEqual(parent_node.subnodes[1].id, "Torch.mul_rank0.forward.0") + + def test_add_node_to_main_graph(self): + graph = MagicMock() + node = MagicMock() + subnode = MagicMock() + node.subnodes = [subnode] + + self.merger._add_node_to_main_graph(graph, node) + graph.node_map.__setitem__.assert_has_calls([call(node.id, node), call(subnode.id, subnode)]) + + def test_get_node_sort_rule(self): + node = MagicMock(id="module.layers.0.forward%1%2") + self.assertEqual(self.merger._get_node_sort_rule(node), (2, 1)) + self.assertEqual(self.merger._get_node_sort_rule(node, rank_ascending=False), (-2, 1)) + + def test_mark_node_id_position_rank(self): + node = MagicMock() + parent_node = MagicMock() + parent_node.subnodes = [MagicMock(), node, MagicMock()] + node.upnode = parent_node + node.id = "module.layers.0.forward" + + self.merger._mark_node_id_position_rank(node, 2) + self.assertEqual(node.id, "module.layers.0.forward%1%2") + + def test_update_node_id(self): + graph = MagicMock() + start_node = MagicMock(id="module.layers.0.forward%1%2") + start_node.op = NodeOp.module + start_node.pp_index = 1 + graph.node_map = {start_node.id: start_node} + + self.merger._update_node_id(graph, start_node) + self.assertEqual(start_node.id, "module.layers.1.forward") + + +class TestTPMerger(unittest.TestCase): + def setUp(self): + self.build_graph_results = [MagicMock(rank=i) for i in range(4)] + self.parallel_param = MagicMock(tp=4, pp=1, rank_size=4) + self.is_bench = False + self.merger = TPMerger(self.build_graph_results, self.parallel_param, self.is_bench) + + def test_merge_params(self): + params = { + "input.0": [ + {Const.MAX: 1, Const.MIN: 0, Const.MEAN: 0.5, Const.NORM: 1}, + {Const.MAX: 2, Const.MIN: 0, Const.MEAN: 0.7, Const.NORM: 1.2} + ] + } + merge_info = self.merger._merge_params(params) + self.assertIn("The Max value merging method for input.0 is: max(1, 2) = 2", merge_info) + self.assertIn("The Mean value merging method for input.0 is: (0.5 + 0.7) / 2 = 0.6", merge_info) + + def test_get_need_merge_node(self): + main_node = MagicMock(id="module.matmul_rank0.forward") + other_graphs = [MagicMock() for _ in range(3)] + tp_merge_mapping = {0: [1, 2, 3]} + + for i, g in enumerate(other_graphs): + g.node_map = {f"module.matmul_rank{i + 1}.forward": MagicMock()} + + nodes = self.merger._get_need_merge_node(main_node, other_graphs, tp_merge_mapping) + self.assertEqual(len(nodes), 0) + + def test_merge_graphs(self): + self.merger.get_groups = MagicMock(return_value=[[0, 1, 2, 3]]) + self.merger.merge_tp_graphs = MagicMock(return_value=self.build_graph_results[:1]) + results = self.merger.merge_graphs() + self.assertEqual(len(results), 1) + + def test_get_groups(self): + for i, result in enumerate(self.build_graph_results): + graph = MagicMock() + node = MagicMock(id=f"all_reduce.{i}") + node.input_data = {f"all_reduce.{i}.input.group": {"group_ranks": [0, 1, 2, 3]}} + graph.node_map.values.return_value = [node] + result.graph = graph + + groups = self.merger.get_groups() + self.assertEqual(len(groups), 1) + self.assertEqual(groups[0], [0, 1, 2, 3]) + + def test_handle_tp_matmul_reduce(self): + node = MagicMock(id=f"module.RowParallelLinear.forward.0") + node.op = NodeOp.module + matmul_node = MagicMock(id="matmul.0") + matmul_node.output_data = {"output.0": {Const.MAX: 1}} + reduce_node = MagicMock(id="all_reduce.0") + reduce_node.input_data = {"input.0": {Const.MAX: 1}} + reduce_node.output_data = {"output.0": {Const.MAX: 2}} + node.subnodes = [matmul_node, reduce_node] + other_graphs = [MagicMock()] + + self.merger._handle_tp_matmul_reduce(node, other_graphs, {}) + self.assertEqual(matmul_node.output_data["output.0"][Const.MAX], 2) + + +class TestNoParallelMerger(unittest.TestCase): + def setUp(self): + self.build_graph_results = [MagicMock()] + self.parallel_param = MagicMock(tp=1, pp=1, rank_size=1) + self.is_bench = False + self.merger = NoParallelMerger(self.build_graph_results, self.parallel_param, self.is_bench) + + def test_merge_graphs(self): + self.merger.merge_graph_api_collection = MagicMock() + results = self.merger.merge_graphs() + self.assertEqual(results, self.build_graph_results) + self.merger.merge_graph_api_collection.assert_called_once_with(self.build_graph_results) + + +class TestTPPPMerger(unittest.TestCase): + def setUp(self): + self.build_graph_results = [MagicMock(rank=i) for i in range(4)] + self.parallel_param = MagicMock(tp=2, pp=2, rank_size=4) + self.is_bench = False + self.merger = TPPPMerger(self.build_graph_results, self.parallel_param, self.is_bench) + + @patch('msprobe.visualization.builder.graph_merger.TPMerger') + @patch('msprobe.visualization.builder.graph_merger.PPMerger') + def test_merge_graphs(self, mock_pp, mock_tp): + tp_merger = MagicMock() + pp_merger = MagicMock() + mock_tp.return_value = tp_merger + mock_pp.return_value = pp_merger + + pp_merger.get_groups.return_value = [[0, 1], [2, 3]] + tp_merger.get_groups.return_value = [[0, 2], [1, 3]] + tp_merger.merge_tp_graphs.return_value = [MagicMock()] + + results = self.merger.merge_graphs() + self.assertEqual(len(results), 1) + + +class TestFullMerger(unittest.TestCase): + def setUp(self): + self.build_graph_results = [MagicMock(rank=i) for i in range(8)] + self.parallel_param = MagicMock(tp=2, pp=4, rank_size=8) + self.is_bench = False + self.merger = FullMerger(self.build_graph_results, self.parallel_param, self.is_bench) + + @patch('msprobe.visualization.builder.graph_merger.TPMerger') + @patch('msprobe.visualization.builder.graph_merger.PPMerger') + def test_merge_graphs(self, mock_pp, mock_tp): + tp_merger = MagicMock() + pp_merger = MagicMock() + mock_tp.return_value = tp_merger + mock_pp.return_value = pp_merger + + pp_merger.get_groups.return_value = [[0, 1, 2, 3], [4, 5, 6, 7]] + tp_merger.get_groups.return_value = [[0, 4], [1, 5], [2, 6], [3, 7]] + + pp_result0 = MagicMock(rank=0) + pp_result1 = MagicMock(rank=4) + pp_merger.merge_pp_graphs.side_effect = [[pp_result0], [pp_result1]] + + tp_merger.merge_tp_graphs.side_effect = [[MagicMock()], [MagicMock()]] + + results = self.merger.merge_graphs() + self.assertEqual(len(results), 1) + + +if __name__ == '__main__': + unittest.main() diff --git a/debug/accuracy_tools/msprobe/test/visualization_ut/builder/test_msprobe_adapter.py b/debug/accuracy_tools/msprobe/test/visualization_ut/builder/test_msprobe_adapter.py index bee32a34a0..e2ca516542 100644 --- a/debug/accuracy_tools/msprobe/test/visualization_ut/builder/test_msprobe_adapter.py +++ b/debug/accuracy_tools/msprobe/test/visualization_ut/builder/test_msprobe_adapter.py @@ -11,6 +11,7 @@ from msprobe.visualization.builder.msprobe_adapter import ( _format_data ) from msprobe.visualization.utils import GraphConst +from msprobe.visualization.graph.base_node import BaseNode import torch from msprobe.core.common.const import Const @@ -55,11 +56,9 @@ class TestMsprobeAdapter(unittest.TestCase): @patch('msprobe.visualization.builder.msprobe_adapter.get_accuracy') def test_compare_node(self, mock_get_accuracy): - node_ids = ["node1", "node2"] - data_dicts = [{'node1': {"input_args": [], "input_kwargs": {}, "output": {}}}, - {'node2': {"input_args": [], "input_kwargs": {}, "output": {}}}] - stack_json_data = {} - result = compare_node(node_ids, data_dicts, stack_json_data, GraphConst.REAL_DATA_COMPARE) + node_n = BaseNode('', 'node1') + node_b = BaseNode('', 'node2') + result = compare_node(node_n, node_b, GraphConst.REAL_DATA_COMPARE) mock_get_accuracy.assert_called_once() self.assertIsInstance(result, list) diff --git a/debug/accuracy_tools/msprobe/test/visualization_ut/test_graph_service.py b/debug/accuracy_tools/msprobe/test/visualization_ut/test_graph_service.py index f9ca5592aa..af988fc01e 100644 --- a/debug/accuracy_tools/msprobe/test/visualization_ut/test_graph_service.py +++ b/debug/accuracy_tools/msprobe/test/visualization_ut/test_graph_service.py @@ -21,6 +21,8 @@ class Args: overflow_check: bool = False fuzzy_match: bool = False complete_stack: bool = False + parallel_merge: bool = False + parallel_params: tuple = None class TestGraphService(unittest.TestCase): diff --git a/debug/accuracy_tools/msprobe/visualization/builder/msprobe_adapter.py b/debug/accuracy_tools/msprobe/visualization/builder/msprobe_adapter.py index cc304c8aa7..6bf6d1ab04 100644 --- a/debug/accuracy_tools/msprobe/visualization/builder/msprobe_adapter.py +++ b/debug/accuracy_tools/msprobe/visualization/builder/msprobe_adapter.py @@ -146,31 +146,27 @@ def format_node_data(data_dict, node_id=None, compare_mode=None): return data_dict -def compare_node(node_ids, data_dicts, stack_json_data, compare_mode): +def compare_node(node_n, node_b, compare_mode): """ 调用acc_compare.py中的get_accuracy获得精度对比指标 真实数据对比模式无法获得精度对比指标,需要调用多进程比对接口 Returns: 包含参数信息和对比指标(真实数据对比模式除外)的list """ - merge_n = _parse_node(node_ids[0], data_dicts[0], stack_json_data, compare_mode) - merge_b = _parse_node(node_ids[1], data_dicts[1], stack_json_data, compare_mode) - result = [] dump_mode = GraphConst.GRAPHCOMPARE_MODE_TO_DUMP_MODE_TO_MAPPING.get(compare_mode) + merge_n = _parse_node(node_n, dump_mode) + merge_b = _parse_node(node_b, dump_mode) + result = [] get_accuracy(result, merge_n, merge_b, dump_mode) return result -def _parse_node(node_id, data_dict, stack_json_data, compare_mode): +def _parse_node(node, dump_mode): """ 转换节点,使其能够作为acc_compare.py中的get_accuracy的入参 """ - dump_mode = GraphConst.GRAPHCOMPARE_MODE_TO_DUMP_MODE_TO_MAPPING.get(compare_mode) - op_parsed_list = read_op(data_dict.get(node_id, {}), node_id) - if node_id in stack_json_data: - op_parsed_list.append( - {'full_op_name': node_id, 'full_info': stack_json_data[node_id]}) - else: - op_parsed_list.append({'full_op_name': node_id, 'full_info': None}) + op_parsed_list = [] + op_parsed_list.extend(node.input_data.values()) + op_parsed_list.extend(node.output_data.values()) result = merge_tensor(op_parsed_list, dump_mode) if not result: result['op_name'] = [] diff --git a/debug/accuracy_tools/msprobe/visualization/compare/graph_comparator.py b/debug/accuracy_tools/msprobe/visualization/compare/graph_comparator.py index 95982658d2..0595a58107 100644 --- a/debug/accuracy_tools/msprobe/visualization/compare/graph_comparator.py +++ b/debug/accuracy_tools/msprobe/visualization/compare/graph_comparator.py @@ -35,13 +35,15 @@ class GraphComparator: self.fuzzy_match = args.fuzzy_match self.pattern = re.compile(r'\.\d+\.') self.is_cross_framework = is_cross_framework + self.parallel_merge = args.parallel_merge if hasattr(args, 'parallel_merge') else False + self.rank_pattern = re.compile(r"_rank\d+") def compare(self): """ 比较函数,初始化结束后单独调用。比较结果写入graph_n """ if self.fuzzy_match: - self._compare_nodes_fuzzy(self.graph_n.root) + self._compare_nodes_fuzzy(self.graph_n.root, False if self.parallel_merge else True) else: self._compare_nodes(self.graph_n.root) self._postcompare() @@ -98,11 +100,12 @@ class GraphComparator: while node_list: compare_single_node(node_list.pop(0)) - def _compare_nodes_fuzzy(self, node_root): + def _compare_nodes_fuzzy(self, node_root, check_shape=True): def compare_single_nodes_fuzzy(node_n): if node_n.op != NodeOp.function_api: # 模块经过模糊匹配 - node_b, ancestors_n, ancestors_b = Graph.fuzzy_match(node_n, self.graph_b.node_map.get(node_n.id)) + node_b, ancestors_n, ancestors_b = Graph.fuzzy_match(node_n, self.graph_b.node_map.get(node_n.id), + check_shape) if node_b: self._process_matched_nodes(node_n, node_b, ancestors_n, ancestors_b) # 匹配上的两个模块中的所有api, 忽略dump调用次数,按照名称一致+模块中的调用顺序进行匹配 @@ -113,7 +116,7 @@ class GraphComparator: if not api_node_n: continue api_node_b, ancestors_n, ancestors_b = Graph.fuzzy_match( - api_node_n, self.graph_b.node_map.get(recount_result_b.get(recount_node_id))) + api_node_n, self.graph_b.node_map.get(recount_result_b.get(recount_node_id)), check_shape) if api_node_b: self._process_matched_nodes(api_node_n, api_node_b, ancestors_n, ancestors_b) node_list.extend(node_n.subnodes) @@ -147,21 +150,26 @@ class GraphComparator: api集合的指标, md5模式使用集合中所有api最小的指标,statistics和tensor模式使用集合中所有api最大的指标 md5模式下指标为0代表最差,statistics和tensor模式下指标为1代表最差 """ + def handle_api_collection_index(api_collection_node): + precision_index = GraphConst.MAX_INDEX_KEY if self.ma.compare_mode == GraphConst.MD5_COMPARE \ + else GraphConst.MIN_INDEX_KEY + for api in api_collection_node.subnodes: + precision_index = min(precision_index, + api.data.get(GraphConst.JSON_INDEX_KEY, GraphConst.MAX_INDEX_KEY)) \ + if self.ma.compare_mode == GraphConst.MD5_COMPARE \ + else max(precision_index, api.data.get(GraphConst.JSON_INDEX_KEY, GraphConst.MIN_INDEX_KEY)) + api_collection_node.data[GraphConst.JSON_INDEX_KEY] = precision_index + for node in self.graph_n.root.subnodes: - if node.op == NodeOp.api_collection: - precision_index = GraphConst.MAX_INDEX_KEY if self.ma.compare_mode == GraphConst.MD5_COMPARE \ - else GraphConst.MIN_INDEX_KEY - for api in node.subnodes: - precision_index = min(precision_index, - api.data.get(GraphConst.JSON_INDEX_KEY, GraphConst.MAX_INDEX_KEY)) \ - if self.ma.compare_mode == GraphConst.MD5_COMPARE \ - else max(precision_index, api.data.get(GraphConst.JSON_INDEX_KEY, GraphConst.MIN_INDEX_KEY)) - node.data[GraphConst.JSON_INDEX_KEY] = precision_index + if node.op == NodeOp.api_collection and node.id.startswith(GraphConst.APIS_BETWEEN_MODULES_ALL_RANKS): + for sub_node in node.subnodes: + handle_api_collection_index(sub_node) + handle_api_collection_index(node) + elif node.op == NodeOp.api_collection: + handle_api_collection_index(node) def _get_and_add_result(self, node_n, node_b): - compare_result_list = compare_node([node_n.id, node_b.id], - [self.data_n_dict, self.data_b_dict], - self.stack_json_data, self.ma.compare_mode) + compare_result_list = compare_node(node_n, node_b, self.ma.compare_mode) if compare_result_list: self.ma.add_csv_data(compare_result_list) self.add_compare_result_to_node(node_n, compare_result_list) @@ -178,6 +186,8 @@ class GraphComparator: if sub_node.op == NodeOp.function_api: # 忽略dump调用次数 count_removed_id = self.pattern.sub(Const.SEP, sub_node.id) + if self.rank_pattern.search(count_removed_id): + count_removed_id = self.rank_pattern.sub('', count_removed_id) node_count[count_removed_id] = node_count.get(count_removed_id, 0) + 1 # 赋予模块中的调用顺序 recount_node_id = count_removed_id + str(node_count.get(count_removed_id)) diff --git a/debug/accuracy_tools/msprobe/visualization/graph/base_node.py b/debug/accuracy_tools/msprobe/visualization/graph/base_node.py index dee8618058..96a16eb8f0 100644 --- a/debug/accuracy_tools/msprobe/visualization/graph/base_node.py +++ b/debug/accuracy_tools/msprobe/visualization/graph/base_node.py @@ -36,6 +36,8 @@ class BaseNode: self.overflow_level = None self.matched_distributed = {} self.batch_p2p_info = [] + self.rank = 0 + self.parallel_merge_info = [] def __str__(self): info = f'id:\t{self.id}' @@ -107,6 +109,8 @@ class BaseNode: result['data'] = self.data if self.matched_distributed: result[GraphConst.MATCHED_DISTRIBUTED] = self.matched_distributed + if self.parallel_merge_info: + result['parallel_merge_info'] = self.parallel_merge_info return result def get_ancestors(self): diff --git a/debug/accuracy_tools/msprobe/visualization/graph/graph.py b/debug/accuracy_tools/msprobe/visualization/graph/graph.py index 5bcad6446c..f4caec221f 100644 --- a/debug/accuracy_tools/msprobe/visualization/graph/graph.py +++ b/debug/accuracy_tools/msprobe/visualization/graph/graph.py @@ -65,8 +65,10 @@ class Graph: return node_b, ancestors_n, ancestors_b @staticmethod - def fuzzy_match(node_n, node_b): - if not node_n or not node_b or not node_n.fuzzy_eq(node_b): + def fuzzy_match(node_n, node_b, check_shape=True): + if not node_n or not node_b: + return None, [], [] + if check_shape and not node_n.fuzzy_eq(node_b): return None, [], [] ancestors_n = node_n.get_ancestors() ancestors_b = node_b.get_ancestors() diff --git a/debug/accuracy_tools/msprobe/visualization/graph_service.py b/debug/accuracy_tools/msprobe/visualization/graph_service.py index a9f7870bea..e4d8e077fe 100644 --- a/debug/accuracy_tools/msprobe/visualization/graph_service.py +++ b/debug/accuracy_tools/msprobe/visualization/graph_service.py @@ -22,7 +22,8 @@ from msprobe.core.common.file_utils import (check_file_type, create_directory, F from msprobe.core.common.const import FileCheckConst, Const from msprobe.core.common.utils import CompareException, get_dump_mode from msprobe.visualization.compare.graph_comparator import GraphComparator -from msprobe.visualization.utils import GraphConst, check_directory_content, SerializableArgs +from msprobe.visualization.utils import GraphConst, check_directory_content, SerializableArgs, load_parallel_param, \ + sort_rank_number_strings, check_whether_parallel_merge, validate_parallel_param, extract_rank_number from msprobe.visualization.builder.graph_builder import GraphBuilder, GraphExportConfig, GraphInfo, BuildGraphTaskInfo from msprobe.core.common.log import logger from msprobe.visualization.graph.node_colors import NodeColors @@ -30,6 +31,7 @@ from msprobe.core.compare.layer_mapping import generate_api_mapping_by_layer_map from msprobe.core.compare.utils import check_and_return_dir_contents from msprobe.core.common.utils import detect_framework_by_dump_json from msprobe.visualization.graph.distributed_analyzer import DistributedAnalyzer +from msprobe.visualization.builder.graph_merger import GraphMerger current_time = time.strftime("%Y%m%d%H%M%S") @@ -101,14 +103,15 @@ def _export_compare_graph_result(args, result): return output_file_name -def _build_graph_info(dump_path, args): +def _build_graph_info(dump_path, args, graph=None): construct_path = FileChecker(os.path.join(dump_path, GraphConst.CONSTRUCT_FILE), FileCheckConst.FILE, FileCheckConst.READ_ABLE).common_check() data_path = FileChecker(os.path.join(dump_path, GraphConst.DUMP_FILE), FileCheckConst.FILE, FileCheckConst.READ_ABLE).common_check() stack_path = FileChecker(os.path.join(dump_path, GraphConst.STACK_FILE), FileCheckConst.FILE, FileCheckConst.READ_ABLE).common_check() - graph = GraphBuilder.build(construct_path, data_path, stack_path, complete_stack=args.complete_stack) + if not graph: + graph = GraphBuilder.build(construct_path, data_path, stack_path, complete_stack=args.complete_stack) return GraphInfo(graph, construct_path, data_path, stack_path) @@ -298,11 +301,12 @@ def _compare_graph_steps(input_param, args): input_param['npu_path'] = os.path.join(dump_step_n, folder_step) input_param['bench_path'] = os.path.join(dump_step_b, folder_step) - _compare_graph_ranks(input_param, args, step=folder_step) + _compare_graph_ranks(input_param, args, step=folder_step) if not args.parallel_merge \ + else _compare_graph_ranks_parallel(input_param, args, step=folder_step) def _build_graph_ranks(dump_ranks_path, args, step=None): - ranks = sorted(check_and_return_dir_contents(dump_ranks_path, Const.RANK)) + ranks = sort_rank_number_strings(check_and_return_dir_contents(dump_ranks_path, Const.RANK)) serializable_args = SerializableArgs(args) with Pool(processes=max(int((cpu_count() + 1) // 4), 1)) as pool: def err_call(err): @@ -319,13 +323,20 @@ def _build_graph_ranks(dump_ranks_path, args, step=None): error_callback=err_call)) build_graph_results = [task.get() for task in build_graph_tasks] - if len(build_graph_results) > 1: + if args.parallel_params: + validate_parallel_param(args.parallel_params[0], dump_ranks_path) + build_graph_results = GraphMerger(build_graph_results, args.parallel_params[0]).merge_graph() + + if len(build_graph_results) > 1 and not args.parallel_merge: DistributedAnalyzer({obj.rank: obj.graph for obj in build_graph_results}, args.overflow_check).distributed_match() create_directory(args.output_path) export_build_graph_tasks = [] - for result in build_graph_results: + for i, result in enumerate(build_graph_results): + if args.parallel_params: + result.output_file_name = f'build_{step}_merged{i}_{current_time}.vis' \ + if step else f'build_merged{i}_{current_time}.vis' export_build_graph_tasks.append(pool.apply_async(_export_build_graph_result, args=(serializable_args, result), error_callback=err_call)) @@ -337,7 +348,6 @@ def _build_graph_ranks(dump_ranks_path, args, step=None): logger.info(f'Successfully exported build graph results.') - def _build_graph_steps(dump_steps_path, args): steps = sorted(check_and_return_dir_contents(dump_steps_path, Const.STEP)) for step in steps: @@ -346,6 +356,76 @@ def _build_graph_steps(dump_steps_path, args): _build_graph_ranks(dump_ranks_path, args, step) +def _compare_and_export_graph(graph_task_info, input_param, args, output_file_name): + result = _run_graph_compare(graph_task_info, input_param, args, output_file_name) + return _export_compare_graph_result(args, result) + + +def _compare_graph_ranks_parallel(input_param, args, step=None): + args.fuzzy_match = True + npu_path = input_param.get('npu_path') + bench_path = input_param.get('bench_path') + ranks_n = sort_rank_number_strings(check_and_return_dir_contents(npu_path, Const.RANK)) + ranks_b = sort_rank_number_strings(check_and_return_dir_contents(bench_path, Const.RANK)) + parallel_params = load_parallel_param(input_param) + if len(parallel_params) != 2: + raise RuntimeError('Parallel params error in compare graph!') + validate_parallel_param(parallel_params[0], npu_path) + validate_parallel_param(parallel_params[1], bench_path, '[Bench]') + serializable_args = SerializableArgs(args) + + with Pool(processes=max(int((cpu_count() + 1) // 4), 1)) as pool: + def err_call(err): + logger.error(f'Error occurred while comparing graph ranks: {err}') + try: + pool.close() + except OSError as e: + logger.error(f'Error occurred while terminating the pool: {e}') + + # 1.并行构图 + build_graph_tasks_n = [] + build_graph_tasks_b = [] + for rank in ranks_n: + build_graph_tasks_n.append(pool.apply_async(_run_build_graph_single, + args=(npu_path, rank, step, serializable_args), + error_callback=err_call)) + for rank in ranks_b: + build_graph_tasks_b.append(pool.apply_async(_run_build_graph_single, + args=(bench_path, rank, step, serializable_args), + error_callback=err_call)) + graph_results_n = [task.get() for task in build_graph_tasks_n] + graph_results_b = [task.get() for task in build_graph_tasks_b] + + # 2.图合并 + build_graph_results_n = GraphMerger(graph_results_n, parallel_params[0]).merge_graph() + build_graph_results_b = GraphMerger(graph_results_b, parallel_params[1], True).merge_graph() + if len(build_graph_results_n) != len(build_graph_results_b): + raise RuntimeError(f'Parallel merge failed because the dp of npu: {len(build_graph_results_n)} ' + f'is inconsistent with that of bench: {len(build_graph_results_b)}!') + # 3.并行图比对和输出 + export_res_task_list = [] + create_directory(args.output_path) + for i, result_n in enumerate(build_graph_results_n): + graph_n = result_n.graph + graph_b = build_graph_results_b[i].graph + graph_task_info = BuildGraphTaskInfo( + _build_graph_info(os.path.join(npu_path, f'rank{graph_n.root.rank}'), args, graph_n), + _build_graph_info(os.path.join(bench_path, f'rank{graph_b.root.rank}'), args, graph_b), + f'rank{graph_n.root.rank}', f'rank{graph_b.root.rank}', current_time) + output_file_name = f'compare_{step}_merged{i}_{current_time}.vis' \ + if step else f'compare_merged{i}_{current_time}.vis' + export_res_task_list.append(pool.apply_async(_compare_and_export_graph, + args=(graph_task_info, input_param, serializable_args, + output_file_name), + error_callback=err_call)) + export_res_list = [res.get() for res in export_res_task_list] + if any(export_res_list): + failed_names = list(filter(lambda x: x, export_res_list)) + logger.error(f'Unable to export compare graph results: {", ".join(failed_names)}.') + else: + logger.info('Successfully exported compare graph results.') + + def _graph_service_parser(parser): parser.add_argument("-i", "--input_path", dest="input_path", type=str, help=" The compare input path, a dict json.", required=True) @@ -365,6 +445,8 @@ def _graph_service_command(args): input_param = load_json(args.input_path) npu_path = input_param.get("npu_path") bench_path = input_param.get("bench_path") + args.parallel_merge = check_whether_parallel_merge(input_param) + args.parallel_params = load_parallel_param(input_param) if args.parallel_merge else None check_file_or_directory_path(npu_path, isdir=True) if bench_path: check_file_or_directory_path(bench_path, isdir=True) @@ -386,7 +468,10 @@ def _graph_service_command(args): if content_n != content_b: raise ValueError('The directory structures of npu_path and bench_path are inconsistent.') if content_n == GraphConst.RANKS: - _compare_graph_ranks(input_param, args) + if args.parallel_merge: + _compare_graph_ranks_parallel(input_param, args) + else: + _compare_graph_ranks(input_param, args) elif content_n == GraphConst.STEPS: _compare_graph_steps(input_param, args) else: @@ -427,7 +512,7 @@ class CompareGraphResult: class BuildGraphResult: - def __init__(self, graph, micro_steps, rank=0, output_file_name=''): + def __init__(self, graph, micro_steps=0, rank=0, output_file_name=''): self.graph = graph self.micro_steps = micro_steps self.rank = rank -- Gitee