From a5d1bba627a4f154ab0f71705b6f411c6915f073 Mon Sep 17 00:00:00 2001
From: w00425040
Date: Thu, 13 Mar 2025 21:43:33 +0800
Subject: [PATCH 1/4] rotary-pos-emb operator interface changed; temporarily stop using the fused operator
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 tests/poc/llama3_405b/pretrain_llama3_405B_8k_256die_A3.sh | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/poc/llama3_405b/pretrain_llama3_405B_8k_256die_A3.sh b/tests/poc/llama3_405b/pretrain_llama3_405B_8k_256die_A3.sh
index 4186513e8..190011adc 100644
--- a/tests/poc/llama3_405b/pretrain_llama3_405B_8k_256die_A3.sh
+++ b/tests/poc/llama3_405b/pretrain_llama3_405B_8k_256die_A3.sh
@@ -69,7 +69,6 @@ GPT_ARGS="
     --position-embedding-type rope \
     --normalization RMSNorm \
     --swiglu \
-    --use-fused-rotary-pos-emb \
     --use-fused-swiglu \
     --use-flash-attn \
     --use-fused-ring-attention-update \
-- 
Gitee

From a4c3dc02e9691f1e3ea3a27c92ab0e0aa94c1cee Mon Sep 17 00:00:00 2001
From: w00425040
Date: Thu, 20 Mar 2025 18:59:16 +0800
Subject: [PATCH 2/4] Note the required MindSpeed version
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 tests/poc/llama3_405b/pretrain_llama3_405B_8k_256die_A3.sh | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/tests/poc/llama3_405b/pretrain_llama3_405B_8k_256die_A3.sh b/tests/poc/llama3_405b/pretrain_llama3_405B_8k_256die_A3.sh
index 24bb4fff7..d09b68642 100644
--- a/tests/poc/llama3_405b/pretrain_llama3_405B_8k_256die_A3.sh
+++ b/tests/poc/llama3_405b/pretrain_llama3_405B_8k_256die_A3.sh
@@ -1,4 +1,8 @@
 #!/bin/bash
+
+# Requires switching the MindSpeed version:
+# git checkout 30c763a22f7afeaf35cb7ef035e61be5f1b9df52 # checkout commit from MindSpeed core_r0.8.0 in 2025.03.04
+
 export CUDA_DEVICE_MAX_CONNECTIONS=1
 export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
 export HCCL_CONNECT_TIMEOUT=2400
-- 
Gitee

From d3478a2f0b5fe9c72e63ffc41c2059cd67ec93ff Mon Sep 17 00:00:00 2001
From: w00425040
Date: Wed, 21 May 2025 14:52:25 +0800
Subject: [PATCH 3/4] Refactor the TP2D feature
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../core/transformer/transformer_block.py     |  4 -
 mindspeed_llm/features_manager/__init__.py    |  4 +-
 .../tensor_parallel/__init__.py               |  0
 .../features_manager/tensor_parallel/tp_2d.py | 97 +++++++++++++++++++
 mindspeed_llm/tasks/megatron_adaptor.py       | 24 +----
 mindspeed_llm/tasks/posttrain/utils.py        |  4 -
 mindspeed_llm/training/arguments.py           | 39 --------
 7 files changed, 102 insertions(+), 70 deletions(-)
 create mode 100644 mindspeed_llm/features_manager/tensor_parallel/__init__.py
 create mode 100644 mindspeed_llm/features_manager/tensor_parallel/tp_2d.py

diff --git a/mindspeed_llm/core/transformer/transformer_block.py b/mindspeed_llm/core/transformer/transformer_block.py
index 61f392257..a48e8da56 100644
--- a/mindspeed_llm/core/transformer/transformer_block.py
+++ b/mindspeed_llm/core/transformer/transformer_block.py
@@ -270,10 +270,6 @@ def transformer_block_forward(
     if self.post_process and self.post_layer_norm and self.final_layernorm is not None:
         hidden_states = self.final_layernorm(hidden_states)
 
-    if get_args().tp_2d and parallel_state.is_pipeline_last_stage():
-        hidden_states = auto_grad_sync_gather_along_first_dim(hidden_states, TPXCollectiveComm)
-        hidden_states = auto_grad_sync_gather_along_last_dim(hidden_states, TPYCollectiveComm)
-
     return hidden_states
 
diff --git a/mindspeed_llm/features_manager/__init__.py b/mindspeed_llm/features_manager/__init__.py
index 3c035eea0..044b74ecb 100644
--- a/mindspeed_llm/features_manager/__init__.py
+++ b/mindspeed_llm/features_manager/__init__.py
@@ -9,6 +9,7 @@ from mindspeed_llm.features_manager.models.mamba import MambaModel
 from mindspeed_llm.features_manager.communication.coc import AscendCocFeature
 from mindspeed_llm.features_manager.common.data import DataFeature
 from mindspeed_llm.features_manager.communication.coc import AscendCocFeature
+from mindspeed_llm.features_manager.tensor_parallel.tp_2d import TP2dFeature
 
 FEATURES_LIST = [
     # MindSpeed Legacy Features
@@ -24,7 +25,8 @@ FEATURES_LIST = [
     RotaryPositionEmbeddingFeature(),
     LanguageModelEmbeddingFeature(),
     MambaModel(),
-    AscendCocFeature()
+    AscendCocFeature(),
+    TP2dFeature(),
 
     # MindSpeed-LLM Legacy Features
 ]
\ No newline at end of file
diff --git a/mindspeed_llm/features_manager/tensor_parallel/__init__.py b/mindspeed_llm/features_manager/tensor_parallel/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/mindspeed_llm/features_manager/tensor_parallel/tp_2d.py b/mindspeed_llm/features_manager/tensor_parallel/tp_2d.py
new file mode 100644
index 000000000..4d1c7dff4
--- /dev/null
+++ b/mindspeed_llm/features_manager/tensor_parallel/tp_2d.py
@@ -0,0 +1,97 @@
+# Copyright (c) 2025, Huawei Technologies Co., Ltd. All rights reserved.
+
+from argparse import ArgumentParser
+
+from mindspeed.features_manager.feature import MindSpeedFeature
+
+
+class TP2dFeature(MindSpeedFeature):
+
+    def __init__(self):
+        super().__init__('tp-2d')
+
+    def register_args(self, parser: ArgumentParser):
+        group = parser.add_argument_group(title=self.feature_name)
+        group.add_argument('--tp-2d', action='store_true', default=False,
+                           help='use tp-2d to replace megatron-style tensor parallel')
+        group.add_argument('--tp-x', type=int, default=1,
+                           help='the first dim tensor parallel size for Linear')
+        group.add_argument('--tp-y', type=int, default=1,
+                           help='the second dim tensor parallel size for Linear')
+        group.add_argument('--enable-overlap-ag-with-matmul', action='store_true', default=False,
+                           help='use enable-overlap-ag-with-matmul to overlap all-gather with matmul')
+        group.add_argument('--enable-overlap-matmul-with-rs', action='store_true', default=False,
+                           help='use enable-overlap-matmul-with-rs to overlap matmul with reduce-scatter')
+        group.add_argument('--enable-backward-overlap-ag-with-matmul', action='store_true', default=False,
+                           help='use enable-backward-overlap-ag-with-matmul to overlap all-gather with matmul in backward')
+
+    def validate_args(self, args):
+        self.incompatible_check(args, 'sequence_parallel')
+        self.incompatible_check(args, 'use_fused_rmsnorm')
+        self.incompatible_check(args, 'use_nanopipe')
+        self.incompatible_check(args, 'use_ascend_coc')
+        if getattr(args, self.feature_name, None):
+            _cp_algo = getattr(args, 'context_parallel_algo', 'ulysses_cp_algo')
+            if _cp_algo not in ['megatron_cp_algo', 'ulysses_cp_algo']:
+                raise AssertionError('tp-2d now only support megatron_cp_algo or ulysses_cp_algo')
+            if not getattr(args, 'use_flash_attn', False) and _cp_algo == 'megatron_cp_algo':
+                args.context_parallel_algo = 'ulysses_cp_algo'
+            if args.tensor_model_parallel_size // args.tp_x != args.tp_y:
+                raise AssertionError('need satisfy tp = tp_x * tp_y')
+            if args.expert_model_parallel_size > 1:
+                raise AssertionError('2d tp does not support moe')
+
+    def register_patches(self, patch_manager, args):
+        if getattr(args, self.feature_name, None):
+            from mindspeed.core.tensor_parallel.tp_2d.norm_factory import get_norm_tp_2d
+            patch_manager.register_patch('megatron.legacy.model.utils.get_norm', get_norm_tp_2d)
+
+            from mindspeed.core.tensor_parallel.tp_2d.norm_factory import _allreduce_layernorm_grads_wrapper
+            patch_manager.register_patch('megatron.core.distributed.finalize_model_grads._allreduce_layernorm_grads',
+                                         _allreduce_layernorm_grads_wrapper)
+
+            from mindspeed.core.models.gpt.gpt_layer_specs import get_mlp_module_spec_wrapper
+            patch_manager.register_patch('megatron.core.models.gpt.gpt_layer_specs._get_mlp_module_spec',
+                                         get_mlp_module_spec_wrapper)
+
+            # Embedding patch: mindspeed_llm/features_manager/common/embedding.py
+
+            from mindspeed.core.pipeline_parallel.schedules import get_tensor_shapes_wrapper
+            patch_manager.register_patch('megatron.core.pipeline_parallel.schedules.get_tensor_shapes',
+                                         get_tensor_shapes_wrapper)
+
+            from mindspeed.core.pipeline_parallel.flexible_schedules import \
+                forward_backward_pipelining_with_interleaving_patch
+            patch_manager.register_patch(
+                'megatron.core.pipeline_parallel.schedules.forward_backward_pipelining_with_interleaving',
+                forward_backward_pipelining_with_interleaving_patch)
+
+            from mindspeed.core.transformer.transformer_block import transformer_block_forward_wrapper
+            patch_manager.register_patch('megatron.core.transformer.transformer_block.TransformerBlock.forward',
+                                         transformer_block_forward_wrapper)
+
+            from mindspeed.core.transformer.transformer_config import transformer_config_post_init
+            patch_manager.register_patch('megatron.core.transformer.transformer_config.TransformerConfig.__post_init__',
+                                         transformer_config_post_init)
+
+            from mindspeed.core.transformer.transformer_config import transformer_config_post_init_wrapper
+            patch_manager.register_patch('megatron.core.transformer.transformer_config.TransformerConfig.__post_init__',
+                                         transformer_config_post_init_wrapper)
+
+            from mindspeed.core.transformer.attention import self_attention_init_tp2d_wrapper
+            patch_manager.register_patch('megatron.core.transformer.attention.SelfAttention.__init__',
+                                         self_attention_init_tp2d_wrapper)
+
+            from mindspeed_llm.core.tensor_parallel.tp_2d.parallel_linear_2d import parallell_linear_2D_init_wrapper
+            patch_manager.register_patch(
+                "mindspeed.core.tensor_parallel.tp_2d.parallel_linear_2d.ParallelLinear2D.__init__",
+                parallell_linear_2D_init_wrapper)
+
+            self.more_patches_for_tp2d(patch_manager, args)
+
+    def more_patches_for_tp2d(self, patch_manager, args):
+        """common features"""
+        # mcore_transformer_adaptation(l2)
+        from mindspeed.core.transformer.module import megatron_module_init_wrapper
+        patch_manager.register_patch('megatron.core.transformer.module.MegatronModule.__init__',
+                                     megatron_module_init_wrapper)
diff --git a/mindspeed_llm/tasks/megatron_adaptor.py b/mindspeed_llm/tasks/megatron_adaptor.py
index 94b68ebdf..1b31a17e3 100644
--- a/mindspeed_llm/tasks/megatron_adaptor.py
+++ b/mindspeed_llm/tasks/megatron_adaptor.py
@@ -266,10 +266,6 @@ class CoreAdaptation(MegatronAdaptationABC):
                                     attention_init)
         MegatronAdaptation.register('megatron.core.transformer.attention.SelfAttention.__init__',
                                     self_attention_init_wrapper)
-        if MegatronAdaptation.get_args().tp_2d:
-            from mindspeed.core.transformer.attention import self_attention_init_tp2d_wrapper
-            MegatronAdaptation.register('megatron.core.transformer.attention.SelfAttention.__init__',
-                                        self_attention_init_tp2d_wrapper)
         MegatronAdaptation.register('megatron.core.transformer.dot_product_attention.DotProductAttention.__init__',
                                     dot_product_attention_init)
 
@@ -329,16 +325,11 @@ class CoreAdaptation(MegatronAdaptationABC):
         from ..core.transformer.transformer_block import _transformer_block_build_layers
         from ..core.transformer.moe.moe_utils import track_moe_metrics_wrapper
+        # note: If you upgrade mindspeed to version 0.10.0 or later, replace PTNorm with from mindspeed.core.transformer.custom_layers.transformer_engine import PTNorm
         from ..core import (PTNorm, topk_router_forward, topk_router_routing, z_loss_func,
                             topk_softmax_with_capacity, get_num_layers_to_build_wrapper,
                             TransformerLayer, topk_router_init_wrapper, transformer_block_init_wrapper,
                             transformer_block_forward, core_mlp_init, topk_router_gating_func)
-        args = MegatronAdaptation.get_args()
-        if args.tp_2d:
-            from mindspeed.core.transformer.transformer_config import transformer_config_post_init
-            MegatronAdaptation.register('megatron.core.transformer.transformer_config.TransformerConfig.__post_init__',
-                                        transformer_config_post_init)
-
         MegatronAdaptation.register('megatron.core.transformer.transformer_config.TransformerConfig.__post_init__',
                                     transformer_config_post_init_wrapper)
         # for mtp
@@ -384,6 +375,7 @@ class CoreAdaptation(MegatronAdaptationABC):
         MegatronAdaptation.register('megatron.core.transformer.moe.experts.GroupedMLP.__init__',
                                     groupedmlp_init_wrapper)
         # For async log loss
+        args = MegatronAdaptation.get_args()
         if args.async_log_allreduce and not args.schedules_method == 'dualpipev':
             from mindspeed.core.training import train_step
             MegatronAdaptation.register('megatron.training.training.train_step', train_step)
@@ -597,11 +589,6 @@ class CoreAdaptation(MegatronAdaptationABC):
         MegatronAdaptation.register('megatron.core.tensor_parallel.layers.RowParallelLinear._load_from_state_dict',
                                     parallel_linear_load_from_state_dict_wrapper)
 
-        if MegatronAdaptation.get_args().tp_2d:
-            from mindspeed_llm.core.tensor_parallel.tp_2d.parallel_linear_2d import parallell_linear_2D_init_wrapper
-            MegatronAdaptation.register(
-                "mindspeed.core.tensor_parallel.tp_2d.parallel_linear_2d.ParallelLinear2D.__init__",
-                parallell_linear_2D_init_wrapper)
 
     def patch_parallel_state(self):
         import megatron
@@ -689,13 +676,6 @@ class CoreAdaptation(MegatronAdaptationABC):
             MegatronAdaptation.register('megatron.core.pipeline_parallel.schedules.forward_backward_pipelining_with_interleaving',
                                         forward_backward_pipelining_with_interleaving_wrapper)
 
-        if args.tp_2d:
-            from mindspeed.core.pipeline_parallel.flexible_schedules import \
-                forward_backward_pipelining_with_interleaving_patch
-            MegatronAdaptation.register(
-                'megatron.core.pipeline_parallel.schedules.forward_backward_pipelining_with_interleaving',
-                forward_backward_pipelining_with_interleaving_patch)
-
     def patch_swap_optimizer(self):
         args = MegatronAdaptation.get_args()
diff --git a/mindspeed_llm/tasks/posttrain/utils.py b/mindspeed_llm/tasks/posttrain/utils.py
index 5bcaab57f..b3c3db4a2 100644
--- a/mindspeed_llm/tasks/posttrain/utils.py
+++ b/mindspeed_llm/tasks/posttrain/utils.py
@@ -116,10 +116,6 @@ def get_tensor_shapes_decorator(get_tensor_shapes):
            config=config
        )
 
-        if args.tp_2d:
-            tensor_shape = [[tensor_shape[0] // args.tp_x, tensor_shape[1], tensor_shape[2] // args.tp_y]
-                            for tensor_shape in tensor_shape]
-
         return tensor_shape
 
     return wrapper
diff --git a/mindspeed_llm/training/arguments.py b/mindspeed_llm/training/arguments.py
index 5654de2dd..e85cef55b 100644
--- a/mindspeed_llm/training/arguments.py
+++ b/mindspeed_llm/training/arguments.py
@@ -68,7 +68,6 @@ def process_args(parser):
     parser = _add_mtp_args(parser)
     parser = _add_rl_args(parser)
     parser = _add_ndmm_args(parser)
-    parser = _add_2d_tp_args(parser)
     parser = _add_hccl_group_buffer_args(parser)
     parser = _add_default_model_args(parser)
     parser = _add_megatron2_args(parser)
@@ -922,23 +921,6 @@ def _add_ndmm_args(parser):
     return parser
 
 
-def _add_2d_tp_args(parser):
-    group = parser.add_argument_group(title='2d-tp')
-    group.add_argument('--tp-2d', action='store_true', default=False,
-                       help='use use-2d-tp to replace megatron-style tensor parallel')
-    group.add_argument('--tp-x', type=int, default=1,
-                       help='the fist dim tensor parallel size for Linear')
-    group.add_argument('--tp-y', type=int, default=1,
-                       help='the second dim tensor parallel size for Linear')
-    group.add_argument('--enable-overlap-ag-with-matmul', action='store_true', default=False,
-                       help='use enable-overlap-ag-with-matmul to overlap all-gather with matmul')
-    group.add_argument('--enable-overlap-matmul-with-rs', action='store_true', default=False,
-                       help='use enable-overlap-matmul-with-rs to overlap matmul with reduce-scatter')
-    group.add_argument('--enable-backward-overlap-ag-with-matmul', action='store_true', default=False,
-                       help='use enable-backward-overlap-ag-with-matmul to overlap all-gather with matmul in backward')
-    return parser
-
-
 def add_parser_argument_choices_value(parser, argument_name, value):
     if parser._actions:
         for action in parser._actions:
@@ -1415,26 +1397,6 @@ def _validate_noop_layer(args):
         args.num_layer_list = None
 
 
-def _valid_tp_2d_args(args):
-    if args.tp_2d:
-        if args.sequence_parallel:
-            raise AssertionError('2d tp does not support sequence parallel')
-        if args.use_fused_rmsnorm:
-            raise AssertionError('2d tp does not support fused rmsnorm')
-        if hasattr(args, "use_nanopipe") and args.use_nanopipe:
-            raise AssertionError('tp-2d does not support nano-pipe')
-        if hasattr(args, "ampipe_degree") and args.ampipe_degree > 1:
-            raise AssertionError('tp-2d does not support ampipe')
-        if hasattr(args, "context_parallel_algo") and args.context_parallel_algo not in ['megatron_cp_algo', 'ulysses_cp_algo']:
-            raise AssertionError('tp-2d now only support megatron_cp_algo or ulysses_cp_algo')
-        if hasattr(args, "use_ascend_coc") and args.use_ascend_coc:
-            raise AssertionError('tp-2d does not support ascend coc')
-        if args.tensor_model_parallel_size // args.tp_x != args.tp_y:
-            raise AssertionError('need satisfy tp = tp_x * tp_y')
-        if args.expert_model_parallel_size > 1:
-            raise AssertionError('2d tp does not support moe')
-
-
 def _validate_vpp(args):
     """validate scenario that vpp is enabled when pp=2."""
     if args.pipeline_model_parallel_size != 2 or args.num_layers_per_virtual_pipeline_stage is None:
@@ -1580,7 +1542,6 @@ def validate_args_decorator(megatron_validate_args):
 
         _validate_dualpipe_args(args)
         _validate_noop_layer(args)
-        _valid_tp_2d_args(args)
         _add_dummy_args(args)
         # remove in future megatron version
         _validate_mtp_args(args)
-- 
Gitee

From 82a076e171f1100f9fb1cd575e683810aaf4dc8a Mon Sep 17 00:00:00 2001
From: 闻江
Date: Thu, 22 May 2025 03:22:59 +0000
Subject: [PATCH 4/4] update mindspeed_llm/features_manager/tensor_parallel/tp_2d.py: codecheck fixes
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: 闻江
---
 mindspeed_llm/features_manager/tensor_parallel/tp_2d.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/mindspeed_llm/features_manager/tensor_parallel/tp_2d.py b/mindspeed_llm/features_manager/tensor_parallel/tp_2d.py
index 4d1c7dff4..65c996e6f 100644
--- a/mindspeed_llm/features_manager/tensor_parallel/tp_2d.py
+++ b/mindspeed_llm/features_manager/tensor_parallel/tp_2d.py
@@ -91,7 +91,6 @@ class TP2dFeature(MindSpeedFeature):
 
     def more_patches_for_tp2d(self, patch_manager, args):
         """common features"""
-        # mcore_transformer_adaptation(l2)
         from mindspeed.core.transformer.module import megatron_module_init_wrapper
         patch_manager.register_patch('megatron.core.transformer.module.MegatronModule.__init__',
                                      megatron_module_init_wrapper)
-- 
Gitee
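
Usage note (illustrative, not part of the patch series above): a minimal sketch of how the TP-2D arguments registered by TP2dFeature might be passed in a pretrain script. The flag names come from tp_2d.py; the parallel sizes are assumptions chosen only to satisfy the validate_args constraint tensor-model-parallel-size == tp-x * tp-y, and the options validate_args rejects (sequence parallel, fused RMSNorm, nanopipe, ascend coc, expert parallelism > 1) are left disabled.

# Sketch only: the sizes below are illustrative assumptions, not values taken from the patches.
# tp = tp_x * tp_y must hold; do not combine with --sequence-parallel, --use-fused-rmsnorm,
# --use-nanopipe or --use-ascend-coc, and keep expert model parallelism at 1.
GPT_ARGS="
    --tensor-model-parallel-size 8 \
    --tp-2d \
    --tp-x 4 \
    --tp-y 2 \
    --enable-overlap-ag-with-matmul \
    --enable-overlap-matmul-with-rs \
    --enable-backward-overlap-ag-with-matmul \
"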