From 6583443bdcc7bc8ced0fd42ce6ac5e93c31d9c7a Mon Sep 17 00:00:00 2001
From: j00841460
Date: Tue, 8 Jul 2025 02:31:23 +0000
Subject: [PATCH] support Qwen2.5-7B with MLA (multi-head-latent-attention)

---
 .../qwen25/pretrain_qwen25_7b_8k_pack_ptd.sh | 14 +++++++++--
 .../core/models/gpt/gpt_layer_specs.py       | 25 ++++++++++++++++++-
 2 files changed, 36 insertions(+), 3 deletions(-)

diff --git a/examples/mcore/qwen25/pretrain_qwen25_7b_8k_pack_ptd.sh b/examples/mcore/qwen25/pretrain_qwen25_7b_8k_pack_ptd.sh
index 01788a20c..9609d94ac 100644
--- a/examples/mcore/qwen25/pretrain_qwen25_7b_8k_pack_ptd.sh
+++ b/examples/mcore/qwen25/pretrain_qwen25_7b_8k_pack_ptd.sh
@@ -17,10 +17,20 @@ TOKENIZER_PATH="your tokenizer path"
 
 TP=1
 PP=4
-SEQ_LEN=8192
+SEQ_LEN=4096
 MBS=1
 GBS=64
 
+MLA_ARGS="
+    --multi-head-latent-attention \
+    --qk-rope-head-dim 128 \
+    --qk-nope-head-dim 256 \
+    --q-lora-rank 384 \
+    --kv-lora-rank 128 \
+    --v-head-dim 256 \
+    --qk-layernorm
+"
+
 DISTRIBUTED_ARGS="
     --nproc_per_node $NPUS_PER_NODE \
     --nnodes $NNODES \
@@ -78,7 +88,6 @@ GPT_ARGS="
     --no-gradient-accumulation-fusion \
     --no-masked-softmax-fusion \
     --attention-softmax-in-fp32 \
-    --overlap-grad-reduce \
     --reuse-fp32-param \
     --use-distributed-optimizer \
     --reset-position-ids \
@@ -113,4 +122,5 @@ torchrun $DISTRIBUTED_ARGS pretrain_gpt.py \
     $DATA_ARGS \
     $CKPT_ARGS \
     $OUTPUT_ARGS \
+    $MLA_ARGS \
     --distributed-backend nccl | tee pretrain_qwen25_7b_8K_pack_ptd.log
diff --git a/mindspeed_llm/core/models/gpt/gpt_layer_specs.py b/mindspeed_llm/core/models/gpt/gpt_layer_specs.py
index 895755aec..9bcedc57a 100644
--- a/mindspeed_llm/core/models/gpt/gpt_layer_specs.py
+++ b/mindspeed_llm/core/models/gpt/gpt_layer_specs.py
@@ -22,6 +22,9 @@ from megatron.core.transformer.moe.moe_layer import MoELayer
 from megatron.training import get_args
 from megatron.core.transformer.transformer_config import TransformerConfig
 from megatron.core.transformer.transformer_block import TransformerBlockSubmodules
+from megatron.core.transformer.enums import AttnMaskType
+from megatron.core.transformer.identity_op import IdentityOp
+from megatron.core.tensor_parallel import ColumnParallelLinear, RowParallelLinear
 
 from mindspeed_llm.core.transformer.transformer_layer import TransformerLayer
 from mindspeed_llm.core.transformer.custom_layers.transformer_engine import PTNorm
@@ -31,7 +34,12 @@ from mindspeed_llm.core.transformer.multi_token_prediction import (
     get_mtp_layer_spec,
     get_mtp_num_layers_to_build,
 )
-
+from mindspeed_llm.tasks.models.transformer.mla_dot_product_attention import MlaDotProductAttention
+from mindspeed_llm.tasks.models.transformer.multi_head_latent_attention import (
+    MLASelfAttentionSubmodules,
+    MultiHeadLatentAttention,
+    LinearNoTP,
+)
 
 def get_gpt_layer_local_spec_wrapper(fn):
     @wraps(fn)
@@ -44,6 +52,21 @@ def get_gpt_layer_local_spec_wrapper(fn):
         if qk_layernorm:
             res.submodules.self_attention.submodules.q_layernorm = PTNorm
             res.submodules.self_attention.submodules.k_layernorm = PTNorm
+
+        if get_args().multi_head_latent_attention:
+            res.submodules.self_attention = ModuleSpec(
+                module=MultiHeadLatentAttention,
+                params={"attn_mask_type": AttnMaskType.causal},
+                submodules=MLASelfAttentionSubmodules(
+                    linear_qkv=LinearNoTP,
+                    core_attention=MlaDotProductAttention,
+                    linear_proj=RowParallelLinear,
+                    q_layernorm=PTNorm if get_args().qk_layernorm else IdentityOp,
+                    k_layernorm=PTNorm if get_args().qk_layernorm else IdentityOp,
+                    linear_qb=ColumnParallelLinear,
+                    linear_kvb=ColumnParallelLinear,
+                )
+            )
         return res
     return wrapper
 
-- Gitee
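
Note (appended for reference, not part of the patch): a minimal Python sketch of how the dimension flags in MLA_ARGS above are assumed to compose under a DeepSeek-V2-style multi-head latent attention layout. The variable names below are illustrative only, not MindSpeed-LLM or Megatron APIs; the authoritative semantics live in the MultiHeadLatentAttention and MlaDotProductAttention classes referenced by the patch.

# Illustrative dimension bookkeeping for the MLA_ARGS values above,
# assuming a DeepSeek-V2-style MLA layout (an assumption, not taken from the patch).
qk_rope_head_dim = 128   # rotary (position-encoded) slice of each Q/K head
qk_nope_head_dim = 256   # non-rotary slice of each Q/K head
v_head_dim = 256         # per-head value width
q_lora_rank = 384        # low-rank bottleneck for the query projection
kv_lora_rank = 128       # low-rank bottleneck shared by keys and values

# The full per-head Q/K width is the rotary plus non-rotary parts.
qk_head_dim = qk_rope_head_dim + qk_nope_head_dim            # 128 + 256 = 384

# In DeepSeek-V2-style MLA, the per-token KV cache holds the compressed KV
# latent plus one shared RoPE key slice instead of full per-head K/V tensors.
kv_cache_width_per_token = kv_lora_rank + qk_rope_head_dim   # 128 + 128 = 256

assert qk_head_dim == 384
assert kv_cache_width_per_token == 256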