diff --git a/examples/mcore/qwen3_moe/generate_qwen3_30b_a3b_ptd_mamba.sh b/examples/mcore/qwen3_moe/generate_qwen3_30b_a3b_ptd_mamba.sh new file mode 100644 index 0000000000000000000000000000000000000000..c33ceda359a0db3d395bb2dc69981bb28146f168 --- /dev/null +++ b/examples/mcore/qwen3_moe/generate_qwen3_30b_a3b_ptd_mamba.sh @@ -0,0 +1,117 @@ +#!/bin/bash + +# The number of parameters is not aligned +export CUDA_DEVICE_MAX_CONNECTIONS=1 + +# please fill these path configurations +TOKENIZER_PATH="/home/ascend-vllm/model/Qwen3-30B-A3B" +CHECKPOINT="/home/ascend-vllm/model/Qwen3-30B-A3B-Mamba2-v4-tp1-pp8-ep1" + + +# Change for multinode config +MASTER_ADDR=localhost +MASTER_PORT=6000 +NNODES=1 +NODE_RANK=0 +NPUS_PER_NODE=8 +WORLD_SIZE=$(($NPUS_PER_NODE*$NNODES)) + +TP=1 +PP=8 +EP=1 +SEQ_LENGTH=2048 +ROUTER_BALANCING_TYPE='softmax_topk' + +DISTRIBUTED_ARGS=" + --nproc_per_node $NPUS_PER_NODE \ + --nnodes $NNODES \ + --node_rank $NODE_RANK \ + --master_addr $MASTER_ADDR \ + --master_port $MASTER_PORT +" + +MOE_ARGS=" + --num-experts 128 \ + --moe-router-topk 8 \ + --moe-router-load-balancing-type ${ROUTER_BALANCING_TYPE} \ + --moe-intermediate-size 768 \ + --moe-permutation-async-comm \ + --moe-token-dispatcher-type allgather \ + --moe-aux-loss-coeff 0.001 +" + + +NUM_LAYERS=96 +LAYER_PATTEN="*-M-*-M-*-M-*-M-*-M-*-M-*-M-*-M-*-M-*-M-*-M-*-M-*-M-*-M-*-M-*-M-*-M-*-M-*-M-*-M-*-M-*-M-*-M-*-*-" +# NUM_LAYERS=8 +# LAYER_PATTEN="*-M-*-M-" +MAMBA_ARGS=" + --reuse-fp32-param \ + --no-shared-storage \ + --use-distributed-optimizer \ + --use-flash-attn \ + --use-mcore-models \ + --num-layers ${NUM_LAYERS} \ + --mamba-ngroups 4 \ + --mamba-chunk-size 128 \ + --mamba-d-state 128 \ + --mamba-d-conv 4 \ + --mamba-expand 2 \ + --mamba-headdim 128 \ + --tokenizer-model ${TOKENIZER_PATH} \ + --hybrid-attention-ratio 0.26 \ + --hybrid-mlp-ratio 0.5 \ + --hybrid-override-pattern $LAYER_PATTEN \ + --untie-embeddings-and-output-weights \ + --overlap-param-gather \ + --overlap-grad-reduce \ + --norm-epsilon 1e-6 \ +" + +torchrun $DISTRIBUTED_ARGS inference_mamba.py \ + $MOE_ARGS \ + $MAMBA_ARGS \ + --use-mcore-models \ + --tensor-model-parallel-size ${TP} \ + --pipeline-model-parallel-size ${PP} \ + --expert-model-parallel-size ${EP} \ + --load ${CHECKPOINT} \ + --moe-grouped-gemm \ + --norm-topk-prob \ + --spec mindspeed_llm.tasks.models.spec.qwen3_mamba_spec layer_spec \ + --kv-channels 128 \ + --qk-layernorm \ + --num-layers ${NUM_LAYERS} \ + --hidden-size 2048 \ + --use-rotary-position-embeddings \ + --num-attention-heads 32 \ + --ffn-hidden-size 8192 \ + --max-position-embeddings 40960 \ + --seq-length ${SEQ_LENGTH} \ + --make-vocab-size-divisible-by 1 \ + --padded-vocab-size 151936 \ + --rotary-base 1000000 \ + --untie-embeddings-and-output-weights \ + --micro-batch-size 1 \ + --disable-bias-linear \ + --swiglu \ + --use-fused-swiglu \ + --use-fused-rmsnorm \ + --tokenizer-type PretrainedFromHF \ + --tokenizer-name-or-path ${TOKENIZER_PATH} \ + --normalization RMSNorm \ + --position-embedding-type rope \ + --norm-epsilon 1e-6 \ + --hidden-dropout 0 \ + --attention-dropout 0 \ + --tokenizer-not-use-fast \ + --max-new-tokens 256 \ + --no-gradient-accumulation-fusion \ + --attention-softmax-in-fp32 \ + --exit-on-missing-checkpoint \ + --no-masked-softmax-fusion \ + --group-query-attention \ + --num-query-groups 4 \ + --seed 42 \ + --bf16 \ + | tee logs/generate_mcore_qwen3_30b_a3b.log diff --git a/examples/mcore/qwen3_moe/pretrain_qwen3_30b_a3b_4K_ptd_mamba.sh 
b/examples/mcore/qwen3_moe/pretrain_qwen3_30b_a3b_4K_ptd_mamba.sh
new file mode 100644
index 0000000000000000000000000000000000000000..d3a9b2ab655611b8d8ebd8e36ee8b0b460737f4d
--- /dev/null
+++ b/examples/mcore/qwen3_moe/pretrain_qwen3_30b_a3b_4K_ptd_mamba.sh
@@ -0,0 +1,178 @@
+#!/bin/bash
+
+export HCCL_CONNECT_TIMEOUT=1800
+export CUDA_DEVICE_MAX_CONNECTIONS=1
+export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
+export NPU_ASD_ENABLE=0
+
+NPUS_PER_NODE=8
+MASTER_ADDR=7.150.14.181
+MASTER_PORT=6000
+NNODES=2
+NODE_RANK=1
+WORLD_SIZE=$(($NPUS_PER_NODE*$NNODES))
+
+# please fill these path configurations
+DATA_PATH="/home/ascend-vllm/dataset/lsb/enwiki20230101/Qwen3-30B-A3B-convert-pretrain_text_document"
+TOKENIZER_PATH="/home/ascend-vllm/model/Qwen3-30B-A3B"
+CKPT_SAVE_DIR="/home/ascend-vllm/model/Qwen3-30B-A3B-mamba"
+CKPT_LOAD_DIR="/home/ascend-vllm/model/Qwen3-30B-A3B-Mamba2-v3-tp1-pp4-ep4"
+
+TP=1
+PP=4
+EP=4
+CP=1
+
+MBS=1
+GBS=256
+SEQ_LENGTH=4096
+TRAIN_ITERS=2000
+CP_TYPE='ulysses_cp_algo'
+ROUTER_BALANCING_TYPE='aux_loss'
+
+DISTRIBUTED_ARGS="
+    --nproc_per_node $NPUS_PER_NODE \
+    --nnodes $NNODES \
+    --node_rank $NODE_RANK \
+    --master_addr $MASTER_ADDR \
+    --master_port $MASTER_PORT
+"
+
+NUM_LAYERS=96
+LAYER_PATTEN="*-M-*-M-*-M-*-M-*-M-*-M-*-M-*-M-*-M-*-M-*-M-*-M-*-M-*-M-*-M-*-M-*-M-*-M-*-M-*-M-*-M-*-M-*-M-*-*-"
+# NUM_LAYERS=8
+# LAYER_PATTEN="*-M-*-M-"
+MAMBA_ARGS="
+    --reuse-fp32-param \
+    --no-shared-storage \
+    --use-distributed-optimizer \
+    --use-flash-attn \
+    --use-mcore-models \
+    --num-layers ${NUM_LAYERS} \
+    --mamba-ngroups 4 \
+    --mamba-chunk-size 128 \
+    --mamba-d-state 128 \
+    --mamba-d-conv 4 \
+    --mamba-expand 2 \
+    --mamba-headdim 128 \
+    --tokenizer-model ${TOKENIZER_PATH} \
+    --hybrid-attention-ratio 0.26 \
+    --hybrid-mlp-ratio 0.5 \
+    --hybrid-override-pattern $LAYER_PATTEN \
+    --untie-embeddings-and-output-weights \
+    --overlap-param-gather \
+    --overlap-grad-reduce \
+    --norm-epsilon 1e-6 \
+"
+
+MOE_ARGS="
+    --num-experts 128 \
+    --moe-router-topk 8 \
+    --moe-router-load-balancing-type ${ROUTER_BALANCING_TYPE} \
+    --moe-intermediate-size 768 \
+    --moe-grouped-gemm \
+    --use-fused-moe-token-permute-and-unpermute \
+    --moe-permutation-async-comm \
+    --moe-token-dispatcher-type alltoall \
+    --moe-aux-loss-coeff 0.001 \
+"
+
+OPTIMIZE_ARGS="
+    --use-flash-attn \
+    --use-fused-rotary-pos-emb \
+    --sequence-parallel \
+    --use-rotary-position-embeddings \
+    --use-fused-swiglu \
+    --use-fused-rmsnorm \
+    --no-masked-softmax-fusion \
+    --use-distributed-optimizer \
+    --gemm-gradient-accumulation-fusion \
+    --recompute-method uniform \
+    --recompute-granularity full \
+    --recompute-num-layers 1 \
+"
+
+TRAIN_ARGS="
+    --micro-batch-size ${MBS} \
+    --global-batch-size ${GBS} \
+    --lr 1.25e-6 \
+    --lr-decay-style cosine \
+    --min-lr 1.25e-7 \
+    --weight-decay 1e-1 \
+    --lr-warmup-fraction 0.01 \
+    --attention-dropout 0.0 \
+    --init-method-std 0.01 \
+    --hidden-dropout 0.0 \
+    --clip-grad 1.0 \
+    --adam-beta1 0.9 \
+    --adam-beta2 0.95 \
+    --initial-loss-scale 4096 \
+    --seed 42 \
+    --bf16 \
+    --train-iters ${TRAIN_ITERS} \
+    --seq-length ${SEQ_LENGTH} \
+    --no-shared-storage
+"
+
+MODEL_PARALLEL_ARGS="
+    --tensor-model-parallel-size ${TP} \
+    --pipeline-model-parallel-size ${PP} \
+    --expert-model-parallel-size ${EP} \
+    --context-parallel-size ${CP} \
+    --context-parallel-algo ${CP_TYPE} \
+"
+
+GPT_ARGS="
+    --use-mcore-models \
+    --spec mindspeed_llm.tasks.models.spec.qwen3_mamba_spec layer_spec \
+    --kv-channels 128 \
+    --qk-layernorm \
+    --norm-topk-prob \
+    --tokenizer-name-or-path ${TOKENIZER_PATH} \
+    --max-position-embeddings ${SEQ_LENGTH} \
+    --num-layers 48 \
+    --hidden-size 2048 \
+    --ffn-hidden-size 6144 \
+    --num-attention-heads 32 \
+    --tokenizer-type PretrainedFromHF \
+    --make-vocab-size-divisible-by 1 \
+    --padded-vocab-size 151936 \
+    --rotary-base 1000000 \
+    --untie-embeddings-and-output-weights \
+    --disable-bias-linear \
+    --position-embedding-type rope \
+    --normalization RMSNorm \
+    --swiglu \
+    --attention-softmax-in-fp32 \
+    --no-gradient-accumulation-fusion \
+    --group-query-attention \
+    --num-query-groups 4
+"
+
+DATA_ARGS="
+    --data-path $DATA_PATH \
+    --split 100,0,0
+"
+
+OUTPUT_ARGS="
+    --log-interval 1 \
+    --save-interval ${TRAIN_ITERS} \
+    --eval-interval ${TRAIN_ITERS} \
+    --eval-iters 0 \
+    --no-load-optim \
+    --no-load-rng
+"
+
+torchrun $DISTRIBUTED_ARGS pretrain_mamba.py \
+    $GPT_ARGS \
+    $DATA_ARGS \
+    $MOE_ARGS \
+    $OUTPUT_ARGS \
+    $OPTIMIZE_ARGS \
+    $TRAIN_ARGS \
+    $MODEL_PARALLEL_ARGS \
+    $MAMBA_ARGS \
+    --distributed-backend nccl \
+    --load ${CKPT_LOAD_DIR} \
+    --save ${CKPT_SAVE_DIR} \
+    | tee logs/train_mcore_qwen3_30b_a3b.log
diff --git a/inference_mamba.py b/inference_mamba.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac4b8964f2bd51096a684814e39c4fae6ed58964
--- /dev/null
+++ b/inference_mamba.py
@@ -0,0 +1,139 @@
+# coding=utf-8
+# Copyright (c) 2024, HUAWEI CORPORATION. All rights reserved.
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Union
+
+from mindspeed_llm import megatron_adaptor
+from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec, \
+    get_gpt_layer_local_spec
+from megatron.core.transformer.spec_utils import import_module
+from megatron.training import get_args, print_rank_0
+from megatron.legacy.model import GPTModel
+from megatron.training.initialize import initialize_megatron
+from megatron.training.arguments import core_transformer_config_from_args
+from megatron.training.yaml_arguments import core_transformer_config_from_yaml
+
+from mindspeed_llm.tasks.inference.infer_base import task_factory
+from mindspeed_llm.tasks.inference.module import GPTModelInfer, MambaModelInfer, MegatronModuleForCausalLM
+from megatron.core.inference_params import InferenceParams
+
+
+def model_provider(pre_process=True, post_process=True) -> Union[MambaModelInfer, GPTModel]:
+    """Builds the model.
+
+    If use_mcore_models is set, this returns the mcore Mamba hybrid model; otherwise it returns the legacy GPT model.
+
+    Args:
+        pre_process (bool, optional): Set to true if you need to compute embeddings. Defaults to True.
+        post_process (bool, optional): Set to true if you need to compute output logits/loss. Defaults to True.
+
+    Returns:
+        Union[MambaModelInfer, GPTModel]: The returned model
+    """
+    args = get_args()
+    use_te = args.transformer_impl == "transformer_engine"
+
+    if args.sequence_parallel and args.use_kv_cache:
+        raise AssertionError('use_kv_cache cannot be enabled in sequence_parallel mode.')
+
+    print_rank_0('building Mamba hybrid model ...')
+    # Experimental loading arguments from yaml
+    if args.yaml_cfg is not None:
+        config = core_transformer_config_from_yaml(args, "language_model")
+    else:
+        config = core_transformer_config_from_args(args)
+
+    if args.spec is not None:
+        mamba_stack_spec = import_module(args.spec)
+    else:
+        raise ValueError("You must provide a valid Mamba layer spec!")
+
+    if args.use_mcore_models:
+
+        model = MambaModelInfer(
+            config=config,
+            mamba_stack_spec=mamba_stack_spec,
+            vocab_size=args.padded_vocab_size,
+            max_sequence_length=args.max_position_embeddings,
+            mamba_ssm_ngroups=args.mamba_ngroups,
+            pre_process=pre_process,
+            hybrid_attention_ratio=args.hybrid_attention_ratio,
+            hybrid_mlp_ratio=args.hybrid_mlp_ratio,
+            hybrid_override_pattern=args.hybrid_override_pattern,
+            post_process=post_process,
+            fp16_lm_cross_entropy=args.fp16_lm_cross_entropy,
+            parallel_output=True,
+            share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights,
+            position_embedding_type=args.position_embedding_type,
+            rotary_percent=args.rotary_percent,
+            rotary_base=args.rotary_base
+        )
+
+    else:
+        if not args.context_parallel_size == 1:
+            raise ValueError("Context parallelism is only supported with Megatron Core!")
+
+        model = GPTModel(
+            config,
+            parallel_output=True if args.sequence_parallel else False,
+            pre_process=pre_process,
+            post_process=post_process
+        )
+
+    return model
+
+
+def main():
+    initialize_megatron(args_defaults={'no_load_rng': True,
+                                       'no_load_optim': True})
+
+    args = get_args()
+
+    model = MegatronModuleForCausalLM.from_pretrained(
+        model_provider=model_provider,
+        pretrained_model_name_or_path=args.load
+    )
+
+    task_factory(args, model)
+
+
+    # # Generate logits for a fixed input (debug helper)
+    # import torch
+    # import numpy as np
+    # from megatron.training.utils import get_ltor_masks_and_position_ids
+    # input_ids = torch.tensor([i for i in range(10000, 12048)]).unsqueeze(0).npu()
+    # eod = 0
+    # reset_position_ids = False
+    # reset_attention_mask = False
+    # eod_mask_loss = False
+    # max_batch_size = 1
+    # max_sequence_length = 2048
+    # attention_mask, loss_mask, _ = get_ltor_masks_and_position_ids(
+    #     input_ids,
+    #     eod,
+    #     reset_position_ids,
+    #     reset_attention_mask,
+    #     eod_mask_loss)
+    # inference_params = InferenceParams(max_batch_size, max_sequence_length)
+    # with torch.no_grad():
+    #     outputs = model.forward(input_ids=input_ids, position_ids=None, attention_mask=attention_mask.npu(), inference_params=inference_params)
+    #     print(outputs.shape, outputs.dtype)
+    #     np.save("./npu_forward_out_mg_qewn3_mamba_logits_fp16.npy", outputs.cpu().numpy())
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/mindspeed_llm/core/ssm/mamba_block.py b/mindspeed_llm/core/ssm/mamba_block.py
index 4c4266b514ca60d6322c020b3a4a96c1fc556f94..3430ff83b0bc17e504b6d726d79f83eb83eee24c 100644
--- a/mindspeed_llm/core/ssm/mamba_block.py
+++ b/mindspeed_llm/core/ssm/mamba_block.py
@@ -79,6 +79,11 @@ def _mamba_block_method_checkpointed_forward_func(
                     inference_params=None,
                     rotary_pos_emb=rotary_pos_emb,
                 )
+                # The attention layer (currently a simplified transformer layer)
+                # outputs a tuple of (hidden_states, context).
Context is intended + # for cross-attention, and is not needed in our model. + if isinstance(hidden_states, tuple): + hidden_states = hidden_states[0] return hidden_states return custom_forward diff --git a/mindspeed_llm/core/ssm/mamba_mixer.py b/mindspeed_llm/core/ssm/mamba_mixer.py index 3ae0b6f77efd2b65ced3dc36657910ca07c8819e..2d60453781110384206d291e05f9c226a00cfe7f 100644 --- a/mindspeed_llm/core/ssm/mamba_mixer.py +++ b/mindspeed_llm/core/ssm/mamba_mixer.py @@ -22,8 +22,8 @@ def mamba_mixer_init_wrapper(fn): kwargs["expand"] = param_args.mamba_expand kwargs["headdim"] = param_args.mamba_headdim fn(self, *args, **kwargs) - dt_min = kwargs.pop('dt_min', 0.001) - dt_max = kwargs.pop('dt_max', 0.1) + dt_min = kwargs.pop('dt_min', 0.0) + dt_max = kwargs.pop('dt_max', float("inf")) self.use_mem_eff_path = False self.d_ssm = param_args.mamba_d_ssm self.dt_min = dt_min @@ -102,9 +102,9 @@ def mamba_mixer_forward(self, hidden_states, seqlen=None, seq_idx=None, cu_seqle x, B, C = torch.split( xBC, [ - self.d_inner_local, self.ngroups_local * self.d_state, self.ngroups_local * self.d_state, + self.d_inner_local, ], dim=-1, ) @@ -131,7 +131,7 @@ def mamba_mixer_forward(self, hidden_states, seqlen=None, seq_idx=None, cu_seqle ) state_opts = StateOptions( - return_final_state=True if ssm_state else False + return_final_state=True if ssm_state is not None else False ) state_space_duality = StateSpaceProcessor(config=config) y = state_space_duality.process(inputs, state_opts) diff --git a/mindspeed_llm/core/transformer/moe/moe_layer.py b/mindspeed_llm/core/transformer/moe/moe_layer.py index c4997a29c6721bc5c07c9e3ca211788958654cbc..52f3de78e8193d9e8a87448956044e740e2a2ae5 100644 --- a/mindspeed_llm/core/transformer/moe/moe_layer.py +++ b/mindspeed_llm/core/transformer/moe/moe_layer.py @@ -93,6 +93,7 @@ def moe_layer_forward(self, hidden_states: torch.Tensor): # process MoE scores, indices = self.router(hidden_states) + scores = scores / scores.sum(dim=-1, keepdim=True) if global_args.moe_revert_type_after_topk: (dispatched_input, tokens_per_expert) = self.token_dispatcher.token_permutation( diff --git a/mindspeed_llm/features_manager/models/mamba.py b/mindspeed_llm/features_manager/models/mamba.py index ac482e58143fa4f15d1abf9e8d183698d54b8fe5..c6cdbe879d076ffc486e88156fe88e4c3b2df98e 100644 --- a/mindspeed_llm/features_manager/models/mamba.py +++ b/mindspeed_llm/features_manager/models/mamba.py @@ -15,7 +15,7 @@ class MambaModel(MindSpeedFeature): group.add_argument('--mamba-d-state', type=int, default=128, help='state dim for mamba') group.add_argument('--mamba-d-conv', type=int, default=4, help='conv channel dim for mamba') group.add_argument('--mamba-expand', type=int, default=1, help='expand scale for mamba') - group.add_argument('--mamba-headdim', type=int, default=80, help='head dim for mamba') + group.add_argument('--mamba-headdim', type=int, default=80, help='head dim for mamba') def register_patches(self, patch_manager, args): from mindspeed_llm.core.ssm.mamba_mixer import mamba_mixer_init_wrapper, mamba_mixer_forward, Mamba2RMSNorm diff --git a/mindspeed_llm/tasks/inference/module.py b/mindspeed_llm/tasks/inference/module.py index cc9da9d483f52d8236ec3df203bb54706ad8b9d1..c4e4d9dcbe472eebb27a04c024acb29b8870d9bc 100644 --- a/mindspeed_llm/tasks/inference/module.py +++ b/mindspeed_llm/tasks/inference/module.py @@ -25,6 +25,7 @@ from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP from megatron.core.models.gpt.gpt_model import GPTModel from megatron.training import 
get_args, global_vars from megatron.core import parallel_state, ModelParallelConfig +from megatron.core.models.mamba import MambaModel class MegatronModuleForCausalLMABC(torch.nn.Module, abc.ABC): @@ -501,5 +502,13 @@ class GPTModelInfer(GPTModel): super().__init__(*args, **kwargs) self.infer_model = MegatronModuleForCausalLM() + def generate(self, input_ids=None, **kwargs): + return self.infer_model.generate(input_ids=input_ids, **kwargs) + +class MambaModelInfer(MambaModel): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.infer_model = MegatronModuleForCausalLM() + def generate(self, input_ids=None, **kwargs): return self.infer_model.generate(input_ids=input_ids, **kwargs) \ No newline at end of file diff --git a/mindspeed_llm/tasks/models/spec/qwen3_mamba_spec.py b/mindspeed_llm/tasks/models/spec/qwen3_mamba_spec.py new file mode 100644 index 0000000000000000000000000000000000000000..008dd51e3da4e4ffb7fd40099b5cf9437c3c4895 --- /dev/null +++ b/mindspeed_llm/tasks/models/spec/qwen3_mamba_spec.py @@ -0,0 +1,157 @@ +from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add +from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear +from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules +from megatron.core.transformer.dot_product_attention import DotProductAttention +from megatron.core.transformer.enums import AttnMaskType +from megatron.core.transformer.identity_op import IdentityOp +from megatron.core.transformer.spec_utils import ModuleSpec +from megatron.core.models.gpt.gpt_layer_specs import _get_mlp_module_spec +from megatron.training import get_args + +from megatron.core.ssm.mamba_block import MambaStack, MambaStackSubmodules +from megatron.core.ssm.mamba_layer import MambaLayer, MambaLayerSubmodules +from megatron.core.ssm.mamba_mixer import MambaMixer, MambaMixerSubmodules + +from megatron.core.transformer import ModuleSpec, TransformerLayer, TransformerLayerSubmodules +from mindspeed_llm.core.transformer.custom_layers.transformer_engine import PTNorm + +args = get_args() +num_experts, moe_grouped_gemm, qk_layernorm = args.num_experts, args.moe_grouped_gemm, args.qk_layernorm + +# # Transformer Layer Spec for Gemma using post_mlp_layernorm and post_mlp_layernorm. 
+# layer_spec = ModuleSpec( +# module=TransformerLayer, +# submodules=TransformerLayerSubmodules( +# input_layernorm=PTNorm, +# self_attention=ModuleSpec( +# module=SelfAttention, +# params={"attn_mask_type": AttnMaskType.causal}, +# submodules=SelfAttentionSubmodules( +# linear_qkv=ColumnParallelLinear, +# core_attention=DotProductAttention, +# linear_proj=RowParallelLinear, +# q_layernorm=PTNorm if qk_layernorm else IdentityOp, +# k_layernorm=PTNorm if qk_layernorm else IdentityOp, +# ), +# ), +# self_attn_bda=get_bias_dropout_add, +# pre_mlp_layernorm=PTNorm, +# mlp=_get_mlp_module_spec( +# use_te=False, num_experts=num_experts, moe_grouped_gemm=moe_grouped_gemm +# ), +# mlp_bda=get_bias_dropout_add, +# sharded_state_dict_keys_map={ +# 'input_layernorm.': 'self_attention.linear_qkv.layer_norm_', +# 'pre_mlp_layernorm.': 'mlp.linear_fc1.layer_norm_', +# }, +# ), +# ) + +# layer_spec = ModuleSpec( +# module=MambaStack, +# submodules=MambaStackSubmodules( +# mamba_layer=ModuleSpec( +# module=MambaLayer, +# submodules=MambaLayerSubmodules( +# norm=PTNorm, +# mixer=ModuleSpec( +# module=MambaMixer, +# submodules=MambaMixerSubmodules( +# in_proj=ColumnParallelLinear, +# out_proj=RowParallelLinear, +# ), +# ), +# mamba_bda=get_bias_dropout_add, +# ), +# ), +# attention_layer=ModuleSpec( +# module=TransformerLayer, +# submodules=TransformerLayerSubmodules( +# input_layernorm=PTNorm, +# self_attention=ModuleSpec( +# module=SelfAttention, +# params={"attn_mask_type": AttnMaskType.causal}, +# submodules=SelfAttentionSubmodules( +# linear_qkv=ColumnParallelLinear, +# core_attention=DotProductAttention, +# linear_proj=RowParallelLinear, +# ), +# ), +# self_attn_bda=get_bias_dropout_add, +# ), +# ), +# mlp_layer=ModuleSpec( +# module=TransformerLayer, +# submodules=TransformerLayerSubmodules( +# pre_mlp_layernorm=PTNorm, +# mlp=ModuleSpec( +# module=MLP, +# submodules=MLPSubmodules( +# linear_fc1=ColumnParallelLinear, +# linear_fc2=RowParallelLinear, +# ), +# ), +# mlp_bda=get_bias_dropout_add, +# ), +# ), +# ), +# ) + +layer_spec = ModuleSpec( + module=MambaStack, + submodules=MambaStackSubmodules( + mamba_layer=ModuleSpec( + module=MambaLayer, + submodules=MambaLayerSubmodules( + norm=PTNorm, + mixer=ModuleSpec( + module=MambaMixer, + submodules=MambaMixerSubmodules( + in_proj=ColumnParallelLinear, + out_proj=RowParallelLinear, + ), + ), + mamba_bda=get_bias_dropout_add, + ), + ), + attention_layer=ModuleSpec( + module=TransformerLayer, + submodules=TransformerLayerSubmodules( + input_layernorm=PTNorm, + self_attention=ModuleSpec( + module=SelfAttention, + params={"attn_mask_type": AttnMaskType.causal}, + submodules=SelfAttentionSubmodules( + linear_qkv=ColumnParallelLinear, + core_attention=DotProductAttention, + linear_proj=RowParallelLinear, + q_layernorm=PTNorm if qk_layernorm else IdentityOp, + k_layernorm=PTNorm if qk_layernorm else IdentityOp, + ), + ), + self_attn_bda=get_bias_dropout_add, + ), + ), + mlp_layer=ModuleSpec( + module=TransformerLayer, + submodules=TransformerLayerSubmodules( + pre_mlp_layernorm=PTNorm, + # mlp=ModuleSpec( + # module=MLP, + # submodules=MLPSubmodules( + # linear_fc1=ColumnParallelLinear, + # linear_fc2=RowParallelLinear, + # ), + # ), + mlp=_get_mlp_module_spec( + use_te=False, num_experts=num_experts, moe_grouped_gemm=moe_grouped_gemm + ), + mlp_bda=get_bias_dropout_add, + sharded_state_dict_keys_map={ + 'input_layernorm.': 'self_attention.linear_qkv.layer_norm_', + 'pre_mlp_layernorm.': 'mlp.linear_fc1.layer_norm_', + }, + ), + ), + ), +) \ No newline at end 
of file diff --git a/mindspeed_llm/tasks/models/ssm/state_space_duality.py b/mindspeed_llm/tasks/models/ssm/state_space_duality.py index 809eeffafe0b42ee3b86144352812fe7ce92c86c..c68d60087e1fb4d3404e7abc166416736f63fa4c 100644 --- a/mindspeed_llm/tasks/models/ssm/state_space_duality.py +++ b/mindspeed_llm/tasks/models/ssm/state_space_duality.py @@ -77,12 +77,14 @@ class StateSpaceProcessor: # Dimension transformations x, dt, A, B, C = self._expand_dims(x, A, dt, B, C) - B_exp, C_exp = self._expand_groups_to_heads(B, C) + # B_exp, C_exp = self._expand_groups_to_heads(B, C) + x_exp, B_exp = self._expand_groups_to_heads(x, B) dt_proc = self._process_time_step(dt) - D = self._prepare_residual(D, x, pad_size) + D = self._prepare_residual(D, x_exp, pad_size) # Chunk processing - x_pad, A_pad, B_pad, C_pad = self._chunk_and_pad(x, dt_proc, A, B_exp, C_exp, pad_size) + # x_pad, A_pad, B_pad, C_pad = self._chunk_and_pad(x, dt_proc, A, B_exp, C_exp, pad_size) + x_pad, A_pad, B_pad, C_pad = self._chunk_and_pad(x_exp, dt_proc, A, B_exp, C, pad_size) # Core computations Y_diag, states, A_cum, C_br = self._compute_diagonal_blocks(A_pad, B_pad, C_pad, x_pad) @@ -93,11 +95,13 @@ class StateSpaceProcessor: return self._synthesize_output((Y_diag, Y_off, D), (pad_size, seq_len), state_opts) def _expand_dims(self, x, A, dt, B, C): - x = rearrange(x, "b l (h p) -> b l h p", p=self.config['headdim']).contiguous() + # x = rearrange(x, "b l (h p) -> b l h p", p=self.config['headdim']).contiguous() + C = rearrange(C, "b l (h p) -> b l h p", p=self.config['headdim']).contiguous() dt = dt.contiguous() A = A.contiguous() B = rearrange(B, "b l (g n) -> b l g n", n=self.config['d_state']).contiguous() - C = rearrange(C, "b l (g n) -> b l g n", n=self.config['d_state']).contiguous() + x = rearrange(x, "b l (g n) -> b l g n", n=self.config['d_state']).contiguous() + # C = rearrange(C, "b l (g n) -> b l g n", n=self.config['d_state']).contiguous() return x, dt, A, B, C def _prepare_initial_states(self, states: Optional[torch.Tensor]) -> torch.Tensor:
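
A note on the hybrid layer layout shared by the two scripts above: --hybrid-override-pattern must contain exactly --num-layers symbols, and in the Megatron-Core hybrid convention 'M' marks a Mamba layer, '*' an attention layer, and '-' an MLP layer (here, the MoE MLP defined in qwen3_mamba_spec.py). The snippet below is a hypothetical helper, not part of this patch; it cross-checks a pattern against the layer count and the two hybrid ratios, and the pattern literal is assumed to expand to the same 96-character string as LAYER_PATTEN in the scripts.

# sanity_check_hybrid_pattern.py -- hypothetical helper, not part of this patch.
# Assumes Megatron-Core hybrid symbols: 'M' = Mamba layer, '*' = attention layer, '-' = MLP layer.

def check_hybrid_pattern(pattern: str, num_layers: int) -> dict:
    """Count layer types in a hybrid override pattern and derive the hybrid ratios."""
    if len(pattern) != num_layers:
        raise ValueError(f"pattern has {len(pattern)} symbols, but --num-layers is {num_layers}")
    counts = {
        "mamba": pattern.count("M"),
        "attention": pattern.count("*"),
        "mlp": pattern.count("-"),
    }
    counts["attention_ratio"] = counts["attention"] / num_layers  # compare with --hybrid-attention-ratio
    counts["mlp_ratio"] = counts["mlp"] / num_layers              # compare with --hybrid-mlp-ratio
    return counts

if __name__ == "__main__":
    # Assumed to be the same 96-character layout as LAYER_PATTEN above.
    pattern = "*-M-" * 23 + "*-*-"
    print(check_hybrid_pattern(pattern, num_layers=96))

If the literal is read correctly, this reports 23 Mamba, 25 attention, and 48 MLP layers, i.e. ratios of roughly 0.26 and 0.5, which line up with the --hybrid-attention-ratio and --hybrid-mlp-ratio values passed in MAMBA_ARGS.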