diff --git a/mindspeed_llm/core/transformer/transformer_block.py b/mindspeed_llm/core/transformer/transformer_block.py
index 61f392257b257f9ad3b96756bcd5de0a3d2be4fc..a48e8da5618589e885a886f6239b29f8cddd9e54 100644
--- a/mindspeed_llm/core/transformer/transformer_block.py
+++ b/mindspeed_llm/core/transformer/transformer_block.py
@@ -270,10 +270,6 @@ def transformer_block_forward(
     if self.post_process and self.post_layer_norm and self.final_layernorm is not None:
         hidden_states = self.final_layernorm(hidden_states)
 
-    if get_args().tp_2d and parallel_state.is_pipeline_last_stage():
-        hidden_states = auto_grad_sync_gather_along_first_dim(hidden_states, TPXCollectiveComm)
-        hidden_states = auto_grad_sync_gather_along_last_dim(hidden_states, TPYCollectiveComm)
-
     return hidden_states
diff --git a/mindspeed_llm/features_manager/__init__.py b/mindspeed_llm/features_manager/__init__.py
index 3c035eea0428f0f33e75b798d46dae7768e63d3b..044b74ecb0aba384e19a0fe19a06c35dc061bf05 100644
--- a/mindspeed_llm/features_manager/__init__.py
+++ b/mindspeed_llm/features_manager/__init__.py
@@ -9,6 +9,7 @@ from mindspeed_llm.features_manager.models.mamba import MambaModel
 from mindspeed_llm.features_manager.communication.coc import AscendCocFeature
 from mindspeed_llm.features_manager.common.data import DataFeature
 from mindspeed_llm.features_manager.communication.coc import AscendCocFeature
+from mindspeed_llm.features_manager.tensor_parallel.tp_2d import TP2dFeature
 
 FEATURES_LIST = [
     # MindSpeed Legacy Features
@@ -24,7 +25,8 @@ FEATURES_LIST = [
     RotaryPositionEmbeddingFeature(),
     LanguageModelEmbeddingFeature(),
     MambaModel(),
-    AscendCocFeature()
+    AscendCocFeature(),
+    TP2dFeature(),
 
     # MindSpeed-LLM Legacy Features
 ]
\ No newline at end of file
diff --git a/mindspeed_llm/features_manager/tensor_parallel/__init__.py b/mindspeed_llm/features_manager/tensor_parallel/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/mindspeed_llm/features_manager/tensor_parallel/tp_2d.py b/mindspeed_llm/features_manager/tensor_parallel/tp_2d.py
new file mode 100644
index 0000000000000000000000000000000000000000..65c996e6f8015ed2398fc32c3421f5b849d3c138
--- /dev/null
+++ b/mindspeed_llm/features_manager/tensor_parallel/tp_2d.py
@@ -0,0 +1,96 @@
+# Copyright (c) 2025, Huawei Technologies Co., Ltd. All rights reserved.
+
+from argparse import ArgumentParser
+
+from mindspeed.features_manager.feature import MindSpeedFeature
+
+
+class TP2dFeature(MindSpeedFeature):
+
+    def __init__(self):
+        super().__init__('tp-2d')
+
+    def register_args(self, parser: ArgumentParser):
+        group = parser.add_argument_group(title=self.feature_name)
+        group.add_argument('--tp-2d', action='store_true', default=False,
+                           help='use 2d tensor parallel to replace megatron-style tensor parallel')
+        group.add_argument('--tp-x', type=int, default=1,
+                           help='the first dim tensor parallel size for Linear')
+        group.add_argument('--tp-y', type=int, default=1,
+                           help='the second dim tensor parallel size for Linear')
+        group.add_argument('--enable-overlap-ag-with-matmul', action='store_true', default=False,
+                           help='overlap all-gather with matmul')
+        group.add_argument('--enable-overlap-matmul-with-rs', action='store_true', default=False,
+                           help='overlap matmul with reduce-scatter')
+        group.add_argument('--enable-backward-overlap-ag-with-matmul', action='store_true', default=False,
+                           help='overlap all-gather with matmul in the backward pass')
+
+    def validate_args(self, args):
+        self.incompatible_check(args, 'sequence_parallel')
+        self.incompatible_check(args, 'use_fused_rmsnorm')
+        self.incompatible_check(args, 'use_nanopipe')
+        self.incompatible_check(args, 'use_ascend_coc')
+        if getattr(args, self.feature_name, None):
+            _cp_algo = getattr(args, 'context_parallel_algo', 'ulysses_cp_algo')
+            if _cp_algo not in ['megatron_cp_algo', 'ulysses_cp_algo']:
+                raise AssertionError('tp-2d currently only supports megatron_cp_algo or ulysses_cp_algo')
+            if not getattr(args, 'use_flash_attn', False) and _cp_algo == 'megatron_cp_algo':
+                args.context_parallel_algo = 'ulysses_cp_algo'
+            if args.tensor_model_parallel_size // args.tp_x != args.tp_y:
+                raise AssertionError('tp-2d requires tensor_model_parallel_size == tp_x * tp_y')
+            if args.expert_model_parallel_size > 1:
+                raise AssertionError('tp-2d does not support MoE')
+
+    def register_patches(self, patch_manager, args):
+        if getattr(args, self.feature_name, None):
+            from mindspeed.core.tensor_parallel.tp_2d.norm_factory import get_norm_tp_2d
+            patch_manager.register_patch('megatron.legacy.model.utils.get_norm', get_norm_tp_2d)
+
+            from mindspeed.core.tensor_parallel.tp_2d.norm_factory import _allreduce_layernorm_grads_wrapper
+            patch_manager.register_patch('megatron.core.distributed.finalize_model_grads._allreduce_layernorm_grads',
+                                         _allreduce_layernorm_grads_wrapper)
+
+            from mindspeed.core.models.gpt.gpt_layer_specs import get_mlp_module_spec_wrapper
+            patch_manager.register_patch('megatron.core.models.gpt.gpt_layer_specs._get_mlp_module_spec',
+                                         get_mlp_module_spec_wrapper)
+
+            # Embedding patch: mindspeed_llm/features_manager/common/embedding.py
+
+            from mindspeed.core.pipeline_parallel.schedules import get_tensor_shapes_wrapper
+            patch_manager.register_patch('megatron.core.pipeline_parallel.schedules.get_tensor_shapes',
+                                         get_tensor_shapes_wrapper)
+
+            from mindspeed.core.pipeline_parallel.flexible_schedules import \
+                forward_backward_pipelining_with_interleaving_patch
+            patch_manager.register_patch(
+                'megatron.core.pipeline_parallel.schedules.forward_backward_pipelining_with_interleaving',
+                forward_backward_pipelining_with_interleaving_patch)
+
+            from mindspeed.core.transformer.transformer_block import transformer_block_forward_wrapper
+            patch_manager.register_patch('megatron.core.transformer.transformer_block.TransformerBlock.forward',
+                                         transformer_block_forward_wrapper)
+
+            from mindspeed.core.transformer.transformer_config import transformer_config_post_init
+            patch_manager.register_patch('megatron.core.transformer.transformer_config.TransformerConfig.__post_init__',
+                                         transformer_config_post_init)
+
+            from mindspeed.core.transformer.transformer_config import transformer_config_post_init_wrapper
+            patch_manager.register_patch('megatron.core.transformer.transformer_config.TransformerConfig.__post_init__',
+                                         transformer_config_post_init_wrapper)
+
+            from mindspeed.core.transformer.attention import self_attention_init_tp2d_wrapper
+            patch_manager.register_patch('megatron.core.transformer.attention.SelfAttention.__init__',
+                                         self_attention_init_tp2d_wrapper)
+
+            from mindspeed_llm.core.tensor_parallel.tp_2d.parallel_linear_2d import parallell_linear_2D_init_wrapper
+            patch_manager.register_patch(
+                "mindspeed.core.tensor_parallel.tp_2d.parallel_linear_2d.ParallelLinear2D.__init__",
+                parallell_linear_2D_init_wrapper)
+
+            self.more_patches_for_tp2d(patch_manager, args)
+
+    def more_patches_for_tp2d(self, patch_manager, args):
+        """common features"""
+        from mindspeed.core.transformer.module import megatron_module_init_wrapper
+        patch_manager.register_patch('megatron.core.transformer.module.MegatronModule.__init__',
+                                     megatron_module_init_wrapper)
diff --git a/mindspeed_llm/tasks/megatron_adaptor.py b/mindspeed_llm/tasks/megatron_adaptor.py
index 94b68ebdff23b16c74b736c5572996e0e3b2e581..1b31a17e335fc35b97209cc616b6c9cc17ef9c69 100644
--- a/mindspeed_llm/tasks/megatron_adaptor.py
+++ b/mindspeed_llm/tasks/megatron_adaptor.py
@@ -266,10 +266,6 @@ class CoreAdaptation(MegatronAdaptationABC):
                                     attention_init)
         MegatronAdaptation.register('megatron.core.transformer.attention.SelfAttention.__init__',
                                     self_attention_init_wrapper)
-        if MegatronAdaptation.get_args().tp_2d:
-            from mindspeed.core.transformer.attention import self_attention_init_tp2d_wrapper
-            MegatronAdaptation.register('megatron.core.transformer.attention.SelfAttention.__init__',
-                                        self_attention_init_tp2d_wrapper)
         MegatronAdaptation.register('megatron.core.transformer.dot_product_attention.DotProductAttention.__init__',
                                     dot_product_attention_init)
 
@@ -329,16 +325,11 @@ class CoreAdaptation(MegatronAdaptationABC):
         from ..core.transformer.transformer_block import _transformer_block_build_layers
         from ..core.transformer.moe.moe_utils import track_moe_metrics_wrapper
 
+        # note: if you upgrade MindSpeed to version 0.10.0 or later, import PTNorm from mindspeed.core.transformer.custom_layers.transformer_engine instead
         from ..core import (PTNorm, topk_router_forward, topk_router_routing, z_loss_func,
                             topk_softmax_with_capacity, get_num_layers_to_build_wrapper,
                             TransformerLayer, topk_router_init_wrapper, transformer_block_init_wrapper,
                             transformer_block_forward, core_mlp_init, topk_router_gating_func)
-        args = MegatronAdaptation.get_args()
-        if args.tp_2d:
-            from mindspeed.core.transformer.transformer_config import transformer_config_post_init
-            MegatronAdaptation.register('megatron.core.transformer.transformer_config.TransformerConfig.__post_init__',
-                                        transformer_config_post_init)
-
         MegatronAdaptation.register('megatron.core.transformer.transformer_config.TransformerConfig.__post_init__',
                                     transformer_config_post_init_wrapper)
         # for mtp
@@ -384,6 +375,7 @@ class CoreAdaptation(MegatronAdaptationABC):
         MegatronAdaptation.register('megatron.core.transformer.moe.experts.GroupedMLP.__init__', groupedmlp_init_wrapper)
         # For async log loss
+        args = MegatronAdaptation.get_args()
         if args.async_log_allreduce and not args.schedules_method == 'dualpipev':
             from mindspeed.core.training import train_step
             MegatronAdaptation.register('megatron.training.training.train_step', train_step)
@@ -597,11 +589,6 @@ class CoreAdaptation(MegatronAdaptationABC):
         MegatronAdaptation.register('megatron.core.tensor_parallel.layers.RowParallelLinear._load_from_state_dict',
                                     parallel_linear_load_from_state_dict_wrapper)
 
-        if MegatronAdaptation.get_args().tp_2d:
-            from mindspeed_llm.core.tensor_parallel.tp_2d.parallel_linear_2d import parallell_linear_2D_init_wrapper
-            MegatronAdaptation.register(
-                "mindspeed.core.tensor_parallel.tp_2d.parallel_linear_2d.ParallelLinear2D.__init__",
-                parallell_linear_2D_init_wrapper)
 
     def patch_parallel_state(self):
         import megatron
@@ -689,13 +676,6 @@ class CoreAdaptation(MegatronAdaptationABC):
         MegatronAdaptation.register('megatron.core.pipeline_parallel.schedules.forward_backward_pipelining_with_interleaving',
                                     forward_backward_pipelining_with_interleaving_wrapper)
 
-        if args.tp_2d:
-            from mindspeed.core.pipeline_parallel.flexible_schedules import \
-                forward_backward_pipelining_with_interleaving_patch
-            MegatronAdaptation.register(
-                'megatron.core.pipeline_parallel.schedules.forward_backward_pipelining_with_interleaving',
-                forward_backward_pipelining_with_interleaving_patch)
-
     def patch_swap_optimizer(self):
         args = MegatronAdaptation.get_args()
diff --git a/mindspeed_llm/tasks/posttrain/utils.py b/mindspeed_llm/tasks/posttrain/utils.py
index 5bcaab57f61e0f6e039b7026c20323b29b56ef48..b3c3db4a2f8dbf1e556267aa92454d3a0a907b6d 100644
--- a/mindspeed_llm/tasks/posttrain/utils.py
+++ b/mindspeed_llm/tasks/posttrain/utils.py
@@ -116,10 +116,6 @@ def get_tensor_shapes_decorator(get_tensor_shapes):
             config=config
         )
 
-        if args.tp_2d:
-            tensor_shape = [[tensor_shape[0] // args.tp_x, tensor_shape[1], tensor_shape[2] // args.tp_y]
-                            for tensor_shape in tensor_shape]
-
         return tensor_shape
 
     return wrapper
diff --git a/mindspeed_llm/training/arguments.py b/mindspeed_llm/training/arguments.py
index 5654de2dddfe0ab4ee4a71990b4c8c8e79a5498f..e85cef55b5b6b14fa154f5da277d36267891d157 100644
--- a/mindspeed_llm/training/arguments.py
+++ b/mindspeed_llm/training/arguments.py
@@ -68,7 +68,6 @@ def process_args(parser):
     parser = _add_mtp_args(parser)
     parser = _add_rl_args(parser)
     parser = _add_ndmm_args(parser)
-    parser = _add_2d_tp_args(parser)
     parser = _add_hccl_group_buffer_args(parser)
     parser = _add_default_model_args(parser)
     parser = _add_megatron2_args(parser)
@@ -922,23 +921,6 @@ def _add_ndmm_args(parser):
     return parser
 
 
-def _add_2d_tp_args(parser):
-    group = parser.add_argument_group(title='2d-tp')
-    group.add_argument('--tp-2d', action='store_true', default=False,
-                       help='use use-2d-tp to replace megatron-style tensor parallel')
-    group.add_argument('--tp-x', type=int, default=1,
-                       help='the fist dim tensor parallel size for Linear')
-    group.add_argument('--tp-y', type=int, default=1,
-                       help='the second dim tensor parallel size for Linear')
-    group.add_argument('--enable-overlap-ag-with-matmul', action='store_true', default=False,
-                       help='use enable-overlap-ag-with-matmul to overlap all-gather with matmul')
-    group.add_argument('--enable-overlap-matmul-with-rs', action='store_true', default=False,
-                       help='use enable-overlap-matmul-with-rs to overlap matmul with reduce-scatter')
-    group.add_argument('--enable-backward-overlap-ag-with-matmul', action='store_true', default=False,
-                       help='use enable-backward-overlap-ag-with-matmul to overlap all-gather with matmul in backward')
-    return parser
-
-
 def add_parser_argument_choices_value(parser, argument_name, value):
     if parser._actions:
         for action in parser._actions:
@@ -1415,26 +1397,6 @@ def _validate_noop_layer(args):
         args.num_layer_list = None
 
 
-def _valid_tp_2d_args(args):
-    if args.tp_2d:
-        if args.sequence_parallel:
-            raise AssertionError('2d tp does not support sequence parallel')
-        if args.use_fused_rmsnorm:
-            raise AssertionError('2d tp does not support fused rmsnorm')
-        if hasattr(args, "use_nanopipe") and args.use_nanopipe:
-            raise AssertionError('tp-2d does not support nano-pipe')
-        if hasattr(args, "ampipe_degree") and args.ampipe_degree > 1:
-            raise AssertionError('tp-2d does not support ampipe')
-        if hasattr(args, "context_parallel_algo") and args.context_parallel_algo not in ['megatron_cp_algo', 'ulysses_cp_algo']:
-            raise AssertionError('tp-2d now only support megatron_cp_algo or ulysses_cp_algo')
-        if hasattr(args, "use_ascend_coc") and args.use_ascend_coc:
-            raise AssertionError('tp-2d does not support ascend coc')
-        if args.tensor_model_parallel_size // args.tp_x != args.tp_y:
-            raise AssertionError('need satisfy tp = tp_x * tp_y')
-        if args.expert_model_parallel_size > 1:
-            raise AssertionError('2d tp does not support moe')
-
-
 def _validate_vpp(args):
     """validate scenario that vpp is enabled when pp=2."""
     if args.pipeline_model_parallel_size != 2 or args.num_layers_per_virtual_pipeline_stage is None:
@@ -1580,7 +1542,6 @@ def validate_args_decorator(megatron_validate_args):
         _validate_dualpipe_args(args)
         _validate_noop_layer(args)
-        _valid_tp_2d_args(args)
 
         _add_dummy_args(args)
         # remove in future megatron version
         _validate_mtp_args(args)