From 8f2f2dbb859422011d6be3cc3f8c1f80479ed4b5 Mon Sep 17 00:00:00 2001
From: guozhihua
Date: Wed, 4 Jun 2025 17:47:30 +0800
Subject: [PATCH] Optimized the performance of mamba2 pretrain

---
 docs/pytorch/models/ssm_model.md                   |  9 +++--
 .../mamba2/pretrain_mamba2_2.7b_4k_ptd.sh          |  1 +
 .../mcore/mamba2/pretrain_mamba2_8b_4k_ptd.sh      |  1 +
 .../pretrain_mamba2_hybrid_8b_4k_ptd.sh            |  1 +
 mindspeed_llm/core/ssm/mamba_mixer.py              | 35 +++++++++----------
 .../features_manager/models/mamba.py               |  4 +--
 pretrain_gpt.py                                    |  5 +++
 pretrain_mamba.py                                  | 12 ++++---
 8 files changed, 42 insertions(+), 26 deletions(-)

diff --git a/docs/pytorch/models/ssm_model.md b/docs/pytorch/models/ssm_model.md
index d3ce95f20..155fa8fdf 100644
--- a/docs/pytorch/models/ssm_model.md
+++ b/docs/pytorch/models/ssm_model.md
@@ -17,7 +17,7 @@
       Mamba2
-      2.7B
+      2.7B
       mamba2
       4K
       Mcore
@@ -47,4 +47,9 @@
 
 ## Environment variable declarations for the model scripts above:
 
-For the environment variable definitions used by the scripts, see [environment_variable.md](../features/environment_variable.md).
\ No newline at end of file
+HCCL_CONNECT_TIMEOUT: sets the HCCL connection timeout; the default value is 120
+CUDA_DEVICE_MAX_CONNECTIONS: defines the number of hardware queues that task streams can use or be mapped to
+PYTORCH_NPU_ALLOC_CONF: memory-fragmentation optimization switch; the default is expandable_segments:False, and it is enabled with expandable_segments:True
+NPUS_PER_NODE: configures the number of NPUs used on a single compute node
+CPU_AFFINITY_CONF: environment variable for CPU core binding (affinity)
+TASK_QUEUE_ENABLE: environment variable that controls second-level pipelined task dispatch
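
For reference, the three pretrain scripts updated by this patch set these variables as follows (values copied verbatim from the script hunks below; NPUS_PER_NODE is a plain shell variable rather than an export):

    export CUDA_DEVICE_MAX_CONNECTIONS=1
    export CPU_AFFINITY_CONF=1
    export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
    export HCCL_CONNECT_TIMEOUT=3600
    export TASK_QUEUE_ENABLE=2
    NPUS_PER_NODE=8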
diff --git a/examples/mcore/mamba2/pretrain_mamba2_2.7b_4k_ptd.sh b/examples/mcore/mamba2/pretrain_mamba2_2.7b_4k_ptd.sh
index af05b8857..1a7b48c43 100644
--- a/examples/mcore/mamba2/pretrain_mamba2_2.7b_4k_ptd.sh
+++ b/examples/mcore/mamba2/pretrain_mamba2_2.7b_4k_ptd.sh
@@ -3,6 +3,7 @@ export CUDA_DEVICE_MAX_CONNECTIONS=1
 export CPU_AFFINITY_CONF=1
 export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
 export HCCL_CONNECT_TIMEOUT=3600
+export TASK_QUEUE_ENABLE=2
 
 NPUS_PER_NODE=8
 MASTER_ADDR=localhost
diff --git a/examples/mcore/mamba2/pretrain_mamba2_8b_4k_ptd.sh b/examples/mcore/mamba2/pretrain_mamba2_8b_4k_ptd.sh
index 8116dc625..618ca64fd 100644
--- a/examples/mcore/mamba2/pretrain_mamba2_8b_4k_ptd.sh
+++ b/examples/mcore/mamba2/pretrain_mamba2_8b_4k_ptd.sh
@@ -3,6 +3,7 @@ export CUDA_DEVICE_MAX_CONNECTIONS=1
 export CPU_AFFINITY_CONF=1
 export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
 export HCCL_CONNECT_TIMEOUT=3600
+export TASK_QUEUE_ENABLE=2
 
 NPUS_PER_NODE=8
 MASTER_ADDR=localhost
diff --git a/examples/mcore/mamba2/pretrain_mamba2_hybrid_8b_4k_ptd.sh b/examples/mcore/mamba2/pretrain_mamba2_hybrid_8b_4k_ptd.sh
index e640b1f63..307b091cc 100644
--- a/examples/mcore/mamba2/pretrain_mamba2_hybrid_8b_4k_ptd.sh
+++ b/examples/mcore/mamba2/pretrain_mamba2_hybrid_8b_4k_ptd.sh
@@ -3,6 +3,7 @@ export CUDA_DEVICE_MAX_CONNECTIONS=1
 export CPU_AFFINITY_CONF=1
 export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
 export HCCL_CONNECT_TIMEOUT=3600
+export TASK_QUEUE_ENABLE=2
 
 NPUS_PER_NODE=8
 MASTER_ADDR=localhost
diff --git a/mindspeed_llm/core/ssm/mamba_mixer.py b/mindspeed_llm/core/ssm/mamba_mixer.py
index 60ed5c717..89b4b01b2 100644
--- a/mindspeed_llm/core/ssm/mamba_mixer.py
+++ b/mindspeed_llm/core/ssm/mamba_mixer.py
@@ -7,6 +7,7 @@ from einops import rearrange
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import torch_npu
 
 from megatron.training import get_args
 from mindspeed_llm.tasks.models.ssm.state_space_duality import StateSpaceProcessor, ProcessInputs, StateOptions
@@ -15,33 +16,20 @@ def mamba_mixer_init_wrapper(fn):
     @wraps(fn)
     def wrapper(self, *args, **kwargs):
         param_args = get_args()
-        kwargs["rmsnorm"] = False
         kwargs["chunk_size"] = param_args.mamba_chunk_size
         kwargs["d_state"] = param_args.mamba_d_state
         kwargs["d_conv"] = param_args.mamba_d_conv
         kwargs["expand"] = param_args.mamba_expand
         kwargs["headdim"] = param_args.mamba_headdim
         fn(self, *args, **kwargs)
-        self.rmsnorm = True
         dt_min = kwargs.pop('dt_min', 0.001)
         dt_max = kwargs.pop('dt_max', 0.1)
-        args = get_args()
         self.use_mem_eff_path = False
         self.d_ssm = param_args.mamba_d_ssm
         self.dt_min = dt_min
         self.dt_max = dt_max
         self.d_ssm_local = self.d_inner_local if self.d_ssm is None else self.d_ssm // self.tensor_model_parallel_size
-
-        if self.rmsnorm:
-            self.norm = Mamba2RMSNorm(
-                self.d_inner_local,
-                eps=1e-5,
-                group_size=self.d_inner_local // self.ngroups_local,
-                norm_before_gate=self.norm_before_gate,
-                device=torch.cuda.current_device(),
-                dtype=self.config.params_dtype
-            )
 
     return wrapper
 
 
@@ -176,7 +164,7 @@ def mamba_mixer_forward(self, hidden_states, seqlen=None, seq_idx=None, cu_seqle
 
 
 class Mamba2RMSNorm(nn.Module):
-    def __init__(self, hidden_size, eps=1e-5, group_size=None, norm_before_gate=True, device=None, dtype=None):
+    def __init__(self, hidden_size, eps=1e-5, group_size=None, norm_before_gate=True, device=None, dtype=None, sequence_parallel: bool = True):
         """If group_size is not None, we do GroupNorm with each group having group_size elements.
        group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
        """
@@ -190,6 +178,8 @@ class Mamba2RMSNorm(nn.Module):
         self.norm_before_gate = norm_before_gate
         self.reset_parameters()
 
+        setattr(self.weight, 'sequence_parallel', sequence_parallel)
+
     def reset_parameters(self):
         torch.nn.init.ones_(self.weight)
 
@@ -198,18 +188,27 @@
     N = x.shape[-1]
     weight = weight.float()
     bias = bias.float() if bias is not None else None
+    args = get_args()
     if upcast:
         x = x.float()
         z = z.float() if z is not None else z
     if z is not None and not norm_before_gate:
         x = x * nn.functional.silu(z)
     if group_size is None:
-        rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
-        out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight)
+        if args.use_fused_rmsnorm:
+            out = torch_npu.npu_rms_norm(x, weight, epsilon=eps)[0]
+        else:
+            rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
+            out = x * rstd * weight
+        out = out + bias if bias is not None else out
     else:
         x_group = rearrange(x, "... (g d) -> ... g d", d=group_size)
-        rstd = 1 / torch.sqrt((x_group.square()).mean(dim=-1, keepdim=True) + eps)
-        out = rearrange(x_group * rstd, "... g d -> ... (g d)") * weight
+        if args.use_fused_rmsnorm:
+            out = torch_npu.npu_rms_norm(x_group, weight.view(-1, group_size), epsilon=eps)[0]
+        else:
+            rstd = 1 / torch.sqrt((x_group.square()).mean(dim=-1, keepdim=True) + eps)
+            out = x_group * rstd * weight.view(-1, group_size)
+        out = rearrange(out, "... g d -> ... (g d)")
         if bias is not None:
             out = out + bias
     if z is not None and norm_before_gate:
diff --git a/mindspeed_llm/features_manager/models/mamba.py b/mindspeed_llm/features_manager/models/mamba.py
index e9f062175..ac482e581 100644
--- a/mindspeed_llm/features_manager/models/mamba.py
+++ b/mindspeed_llm/features_manager/models/mamba.py
@@ -18,11 +18,11 @@ class MambaModel(MindSpeedFeature):
         group.add_argument('--mamba-headdim', type=int, default=80, help='head dim for mamba')
 
     def register_patches(self, patch_manager, args):
-        from mindspeed_llm.core.ssm.mamba_mixer import mamba_mixer_init_wrapper, mamba_mixer_forward
+        from mindspeed_llm.core.ssm.mamba_mixer import mamba_mixer_init_wrapper, mamba_mixer_forward, Mamba2RMSNorm
         from mindspeed_llm.core.ssm.mamba_block import mamba_block_forward
 
         patch_manager.register_patch(
-            'mamba_ssm.ops.triton.layernorm_gated.RMSNorm',
+            'mamba_ssm.ops.triton.layernorm_gated.RMSNorm', Mamba2RMSNorm,
             create_dummy=True)
         patch_manager.register_patch(
             'mamba_ssm.ops.triton.ssd_combined.mamba_chunk_scan_combined',
diff --git a/pretrain_gpt.py b/pretrain_gpt.py
index 504abc662..01bf9e88c 100644
--- a/pretrain_gpt.py
+++ b/pretrain_gpt.py
@@ -106,6 +106,11 @@ def get_batch(data_iterator):
 
     args = get_args()
 
+    is_middle_stage = not (mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage())
+    pretrain_not_tnd_flags = not args.is_instruction_dataset and not args.reset_position_ids
+    if pretrain_not_tnd_flags and is_middle_stage:
+        return (None,) * 5
+
     # get batches based on the TP rank you are on
     batch, actual_seq_len = get_batch_on_this_tp_rank(data_iterator)
 
diff --git a/pretrain_mamba.py b/pretrain_mamba.py
index 6888e4402..926799eaf 100644
--- a/pretrain_mamba.py
+++ b/pretrain_mamba.py
@@ -94,7 +94,13 @@ def model_provider(pre_process=True, post_process=True) -> MambaModel:
 
 
 def get_batch(data_iterator):
     """Generate a batch."""
+    args = get_args()
+    is_middle_stage = not (mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage())
+    pretrain_not_tnd_flags = not args.is_instruction_dataset and not args.reset_position_ids
+    if pretrain_not_tnd_flags and is_middle_stage:
+        return (None,) * 5
+
     # get batches based on the TP rank you are on
     batch, actual_seq_len = get_batch_on_this_tp_rank(data_iterator)
     args = get_args()
@@ -110,7 +116,7 @@
     batch.pop('idx', None)
 
     if args.reset_position_ids:
-        generate_actual_seq_len(batch)
+        generate_actual_seq_len(batch, actual_seq_len)
     # slice batch along sequence dimension for context parallelism
     batch = get_batch_on_this_cp_rank(batch)
     return batch.values()
@@ -173,9 +179,7 @@ def forward_step(data_iterator, model: MambaModel):
 
 
 def is_dataset_built_on_rank():
-    return (
-        mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage()
-    ) and mpu.get_tensor_model_parallel_rank() == 0
+    return mpu.get_tensor_model_parallel_rank() == 0
 
 
 def core_gpt_dataset_config_from_args(args):
-- 
Gitee
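
As background for the Mamba2RMSNorm change above, the following standalone sketch (not part of the patch) spells out the reference RMSNorm that the non-fused branch keeps and that torch_npu.npu_rms_norm replaces when args.use_fused_rmsnorm is set; the tensor shapes, group size, and epsilon below are illustrative assumptions only.

    import torch
    from einops import rearrange

    def rms_norm_ref(x, weight, eps=1e-5):
        # Reference RMSNorm over the last dimension (the non-fused branch above).
        rstd = 1 / torch.sqrt(x.square().mean(dim=-1, keepdim=True) + eps)
        return x * rstd * weight

    def grouped_rms_norm_ref(x, weight, group_size, eps=1e-5):
        # Grouped variant: normalize each group of group_size channels independently,
        # as the group_size branch above does before rearranging back.
        x_group = rearrange(x, "... (g d) -> ... g d", d=group_size)
        rstd = 1 / torch.sqrt(x_group.square().mean(dim=-1, keepdim=True) + eps)
        out = x_group * rstd * weight.view(-1, group_size)
        return rearrange(out, "... g d -> ... (g d)")

    # Illustrative shapes only; on an NPU build, the fused equivalent of
    # rms_norm_ref(x, w) is torch_npu.npu_rms_norm(x, w, epsilon=1e-5)[0].
    x = torch.randn(4, 1024, 256)
    w = torch.ones(256)
    y = rms_norm_ref(x, w)
    y_grouped = grouped_rms_norm_ref(x, w, group_size=64)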