diff --git a/docs/pytorch/models/ssm_model.md b/docs/pytorch/models/ssm_model.md
index d3ce95f20e7af5820ed8e06ac8dd4437f70c1d68..155fa8fdfea7b54383d7a4b18cf66f877d9bd4d5 100644
--- a/docs/pytorch/models/ssm_model.md
+++ b/docs/pytorch/models/ssm_model.md
@@ -17,7 +17,7 @@
-| Mamba2 | 2.7B | mamba2 | 4K | Mcore |
+| Mamba2 | 2.7B | mamba2 | 4K | Mcore |
@@ -47,4 +47,9 @@
## Environment variable declarations for the above model scripts:
-For the environment variable definitions used by the scripts, see [environment_variable.md](../features/environment_variable.md).
\ No newline at end of file
+HCCL_CONNECT_TIMEOUT: sets the HCCL connection timeout; the default value is 120.
+CUDA_DEVICE_MAX_CONNECTIONS: defines the number of hardware queues that task streams can use or be mapped to.
+PYTORCH_NPU_ALLOC_CONF: switch for memory-fragmentation optimization; the default is expandable_segments:False, and setting expandable_segments:True enables it.
+NPUS_PER_NODE: configures the number of NPUs used on a single compute node.
+CPU_AFFINITY_CONF: environment variable for CPU core binding.
+TASK_QUEUE_ENABLE: environment variable for level-2 pipelined task dispatch.
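These are the same variables that the updated mamba2 pretrain scripts export. As a minimal, purely illustrative sketch (values taken from those scripts; the helper itself is not part of the repository), they could also be set programmatically before launching a run:

```python
# Illustrative only: set the documented environment variables from Python,
# mirroring the exports in examples/mcore/mamba2/pretrain_mamba2_*.sh.
import os

env_defaults = {
    "HCCL_CONNECT_TIMEOUT": "3600",            # HCCL connection timeout in seconds (doc default: 120)
    "CUDA_DEVICE_MAX_CONNECTIONS": "1",        # number of hardware queues task streams may map to
    "PYTORCH_NPU_ALLOC_CONF": "expandable_segments:True",  # enable memory-fragmentation optimization
    "NPUS_PER_NODE": "8",                      # NPUs used on one compute node
    "CPU_AFFINITY_CONF": "1",                  # CPU core binding
    "TASK_QUEUE_ENABLE": "2",                  # level-2 pipelined task dispatch
}

for name, value in env_defaults.items():
    os.environ.setdefault(name, value)  # keep any value already set by the caller
```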
diff --git a/examples/mcore/mamba2/pretrain_mamba2_2.7b_4k_ptd.sh b/examples/mcore/mamba2/pretrain_mamba2_2.7b_4k_ptd.sh
index af05b88574ecf8ce5ff01f9bb4f67e1b988361df..1a7b48c434da27665b2e7e1874ff7b07c08cb7b8 100644
--- a/examples/mcore/mamba2/pretrain_mamba2_2.7b_4k_ptd.sh
+++ b/examples/mcore/mamba2/pretrain_mamba2_2.7b_4k_ptd.sh
@@ -3,6 +3,7 @@ export CUDA_DEVICE_MAX_CONNECTIONS=1
export CPU_AFFINITY_CONF=1
export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
export HCCL_CONNECT_TIMEOUT=3600
+export TASK_QUEUE_ENABLE=2
NPUS_PER_NODE=8
MASTER_ADDR=localhost
diff --git a/examples/mcore/mamba2/pretrain_mamba2_8b_4k_ptd.sh b/examples/mcore/mamba2/pretrain_mamba2_8b_4k_ptd.sh
index 8116dc62563536b40d7499090abb5f55217c3806..618ca64fd123b6a843e71abbe421bac429da81c1 100644
--- a/examples/mcore/mamba2/pretrain_mamba2_8b_4k_ptd.sh
+++ b/examples/mcore/mamba2/pretrain_mamba2_8b_4k_ptd.sh
@@ -3,6 +3,7 @@ export CUDA_DEVICE_MAX_CONNECTIONS=1
export CPU_AFFINITY_CONF=1
export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
export HCCL_CONNECT_TIMEOUT=3600
+export TASK_QUEUE_ENABLE=2
NPUS_PER_NODE=8
MASTER_ADDR=localhost
diff --git a/examples/mcore/mamba2/pretrain_mamba2_hybrid_8b_4k_ptd.sh b/examples/mcore/mamba2/pretrain_mamba2_hybrid_8b_4k_ptd.sh
index e640b1f63ef7f35c86a192cbb587855b155f41d2..307b091cc5fea455e304345263254675b770b5aa 100644
--- a/examples/mcore/mamba2/pretrain_mamba2_hybrid_8b_4k_ptd.sh
+++ b/examples/mcore/mamba2/pretrain_mamba2_hybrid_8b_4k_ptd.sh
@@ -3,6 +3,7 @@ export CUDA_DEVICE_MAX_CONNECTIONS=1
export CPU_AFFINITY_CONF=1
export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
export HCCL_CONNECT_TIMEOUT=3600
+export TASK_QUEUE_ENABLE=2
NPUS_PER_NODE=8
MASTER_ADDR=localhost
diff --git a/mindspeed_llm/core/ssm/mamba_mixer.py b/mindspeed_llm/core/ssm/mamba_mixer.py
index 60ed5c7171e87a6e2a14bccf7e1a379a9c915799..89b4b01b25c7db11e2a754d5a99fce6d1633837c 100644
--- a/mindspeed_llm/core/ssm/mamba_mixer.py
+++ b/mindspeed_llm/core/ssm/mamba_mixer.py
@@ -7,6 +7,7 @@ from einops import rearrange
import torch
import torch.nn as nn
import torch.nn.functional as F
+import torch_npu
from megatron.training import get_args
from mindspeed_llm.tasks.models.ssm.state_space_duality import StateSpaceProcessor, ProcessInputs, StateOptions
@@ -15,33 +16,20 @@ def mamba_mixer_init_wrapper(fn):
@wraps(fn)
def wrapper(self, *args, **kwargs):
param_args = get_args()
- kwargs["rmsnorm"] = False
kwargs["chunk_size"] = param_args.mamba_chunk_size
kwargs["d_state"] = param_args.mamba_d_state
kwargs["d_conv"] = param_args.mamba_d_conv
kwargs["expand"] = param_args.mamba_expand
kwargs["headdim"] = param_args.mamba_headdim
fn(self, *args, **kwargs)
- self.rmsnorm = True
dt_min = kwargs.pop('dt_min', 0.001)
dt_max = kwargs.pop('dt_max', 0.1)
- args = get_args()
self.use_mem_eff_path = False
self.d_ssm = param_args.mamba_d_ssm
self.dt_min = dt_min
self.dt_max = dt_max
self.d_ssm_local = self.d_inner_local if self.d_ssm is None else self.d_ssm // self.tensor_model_parallel_size
-
- if self.rmsnorm:
- self.norm = Mamba2RMSNorm(
- self.d_inner_local,
- eps=1e-5,
- group_size=self.d_inner_local // self.ngroups_local,
- norm_before_gate=self.norm_before_gate,
- device=torch.cuda.current_device(),
- dtype=self.config.params_dtype
- )
return wrapper
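With the forced `rmsnorm = False` and the in-wrapper `Mamba2RMSNorm` construction removed, the wrapper is reduced to injecting CLI-derived kwargs and setting a few post-init attributes. A generic sketch of that decorator pattern, with illustrative names only:

```python
# Sketch of the "inject kwargs, then post-process" __init__ wrapper pattern
# used by mamba_mixer_init_wrapper; names other than functools.wraps are illustrative.
from functools import wraps

def inject_init_kwargs(fn, injected):
    """Return an __init__ wrapper that merges `injected` kwargs before delegating."""
    @wraps(fn)
    def wrapper(self, *args, **kwargs):
        kwargs.update(injected)        # e.g. chunk_size, d_state, d_conv, expand, headdim
        fn(self, *args, **kwargs)      # run the original __init__ with the injected values
        self.use_mem_eff_path = False  # attributes set after construction, as above
    return wrapper
```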
@@ -176,7 +164,7 @@ def mamba_mixer_forward(self, hidden_states, seqlen=None, seq_idx=None, cu_seqle
class Mamba2RMSNorm(nn.Module):
- def __init__(self, hidden_size, eps=1e-5, group_size=None, norm_before_gate=True, device=None, dtype=None):
+ def __init__(self, hidden_size, eps=1e-5, group_size=None, norm_before_gate=True, device=None, dtype=None, sequence_parallel: bool = True):
"""If group_size is not None, we do GroupNorm with each group having group_size elements.
group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
"""
@@ -190,6 +178,8 @@ class Mamba2RMSNorm(nn.Module):
self.norm_before_gate = norm_before_gate
self.reset_parameters()
+ setattr(self.weight, 'sequence_parallel', sequence_parallel)
+
def reset_parameters(self):
torch.nn.init.ones_(self.weight)
@@ -198,18 +188,27 @@ class Mamba2RMSNorm(nn.Module):
N = x.shape[-1]
weight = weight.float()
bias = bias.float() if bias is not None else None
+ args = get_args()
if upcast:
x = x.float()
z = z.float() if z is not None else z
if z is not None and not norm_before_gate:
x = x * nn.functional.silu(z)
if group_size is None:
- rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
- out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight)
+ if args.use_fused_rmsnorm:
+ out = torch_npu.npu_rms_norm(x, weight, epsilon=eps)[0]
+ else:
+ rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
+ out = x * rstd * weight
+ out = out + bias if bias is not None else out
else:
x_group = rearrange(x, "... (g d) -> ... g d", d=group_size)
- rstd = 1 / torch.sqrt((x_group.square()).mean(dim=-1, keepdim=True) + eps)
- out = rearrange(x_group * rstd, "... g d -> ... (g d)") * weight
+ if args.use_fused_rmsnorm:
+ out = torch_npu.npu_rms_norm(x_group, weight.view(-1, group_size), epsilon=eps)[0]
+ else:
+ rstd = 1 / torch.sqrt((x_group.square()).mean(dim=-1, keepdim=True) + eps)
+ out = x_group * rstd * weight.view(-1, group_size)
+ out = rearrange(out, "... g d -> ... (g d)")
if bias is not None:
out = out + bias
if z is not None and norm_before_gate:
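For reference, the non-fused branches above compute standard (optionally grouped) RMSNorm; the fused `torch_npu.npu_rms_norm` path is expected to match this up to numerics. A standalone sketch of that fallback math, not the patched class itself:

```python
import torch

def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor,
                       eps: float = 1e-5, group_size: int = None) -> torch.Tensor:
    """Fallback math from the branches above: plain or per-group RMSNorm."""
    if group_size is None:
        rstd = 1.0 / torch.sqrt(x.square().mean(dim=-1, keepdim=True) + eps)
        return x * rstd * weight
    # Grouped case: normalize each group of `group_size` channels independently.
    g = x.shape[-1] // group_size
    x_group = x.reshape(*x.shape[:-1], g, group_size)
    rstd = 1.0 / torch.sqrt(x_group.square().mean(dim=-1, keepdim=True) + eps)
    out = x_group * rstd * weight.reshape(g, group_size)
    return out.reshape(*x.shape)
```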
diff --git a/mindspeed_llm/features_manager/models/mamba.py b/mindspeed_llm/features_manager/models/mamba.py
index e9f062175b06189b102a1aef2645cbb4fd2c4069..ac482e58143fa4f15d1abf9e8d183698d54b8fe5 100644
--- a/mindspeed_llm/features_manager/models/mamba.py
+++ b/mindspeed_llm/features_manager/models/mamba.py
@@ -18,11 +18,11 @@ class MambaModel(MindSpeedFeature):
group.add_argument('--mamba-headdim', type=int, default=80, help='head dim for mamba')
def register_patches(self, patch_manager, args):
- from mindspeed_llm.core.ssm.mamba_mixer import mamba_mixer_init_wrapper, mamba_mixer_forward
+ from mindspeed_llm.core.ssm.mamba_mixer import mamba_mixer_init_wrapper, mamba_mixer_forward, Mamba2RMSNorm
from mindspeed_llm.core.ssm.mamba_block import mamba_block_forward
patch_manager.register_patch(
- 'mamba_ssm.ops.triton.layernorm_gated.RMSNorm',
+ 'mamba_ssm.ops.triton.layernorm_gated.RMSNorm', Mamba2RMSNorm,
create_dummy=True)
patch_manager.register_patch(
'mamba_ssm.ops.triton.ssd_combined.mamba_chunk_scan_combined',
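Passing `Mamba2RMSNorm` as the replacement means the dotted `mamba_ssm` path now resolves to the NPU-friendly class rather than a bare dummy. A rough, illustrative sketch of what such a substitution amounts to (the real `patch_manager` additionally handles dummy-module creation, hence `create_dummy=True`):

```python
# Illustrative only: conceptually, registering a patch rebinds the attribute
# named by a dotted path to the supplied replacement object.
import importlib

def apply_patch(dotted_path: str, replacement) -> None:
    module_path, attr_name = dotted_path.rsplit(".", 1)
    module = importlib.import_module(module_path)  # e.g. mamba_ssm.ops.triton.layernorm_gated
    setattr(module, attr_name, replacement)        # e.g. RMSNorm -> Mamba2RMSNorm
```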
diff --git a/pretrain_gpt.py b/pretrain_gpt.py
index 504abc66258347371b1ee27798753fb58cfc661e..01bf9e88c5ccc7471e6962869e12a99d2bb2de82 100644
--- a/pretrain_gpt.py
+++ b/pretrain_gpt.py
@@ -106,6 +106,11 @@ def get_batch(data_iterator):
args = get_args()
+ is_middle_stage = not (mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage())
+ pretrain_not_tnd_flags = not args.is_instruction_dataset and not args.reset_position_ids
+ if pretrain_not_tnd_flags and is_middle_stage:
+ return (None,) * 5
+
# get batches based on the TP rank you are on
batch, actual_seq_len = get_batch_on_this_tp_rank(data_iterator)
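The early return mirrors the one added to pretrain_mamba.py below: middle pipeline stages in plain pretraining (no instruction dataset, no `reset_position_ids`) never consume the batch, so five placeholders stand in for the tokens/labels/loss-mask/attention-mask/position-ids values that `get_batch` normally yields. A condensed sketch of the predicate, with an illustrative helper name:

```python
def skip_batch_on_middle_stage(args, mpu) -> bool:
    """True when this pipeline stage can return (None,) * 5 instead of real batch tensors."""
    is_middle_stage = not (mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage())
    plain_pretrain = not args.is_instruction_dataset and not args.reset_position_ids
    return plain_pretrain and is_middle_stage
```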
diff --git a/pretrain_mamba.py b/pretrain_mamba.py
index 6888e4402fe5a4a6bea2298f440aba91e3bd330a..926799eaf3df635298da3ef15e37b148a743c6b4 100644
--- a/pretrain_mamba.py
+++ b/pretrain_mamba.py
@@ -94,7 +94,13 @@ def model_provider(pre_process=True, post_process=True) -> MambaModel:
def get_batch(data_iterator):
"""Generate a batch."""
+ args = get_args()
+ is_middle_stage = not (mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage())
+ pretrain_not_tnd_flags = not args.is_instruction_dataset and not args.reset_position_ids
+ if pretrain_not_tnd_flags and is_middle_stage:
+ return (None,) * 5
+
# get batches based on the TP rank you are on
batch, actual_seq_len = get_batch_on_this_tp_rank(data_iterator)
args = get_args()
@@ -110,7 +116,7 @@ def get_batch(data_iterator):
batch.pop('idx', None)
if args.reset_position_ids:
- generate_actual_seq_len(batch)
+ generate_actual_seq_len(batch, actual_seq_len)
# slice batch along sequence dimension for context parallelism
batch = get_batch_on_this_cp_rank(batch)
return batch.values()
@@ -173,9 +179,7 @@ def forward_step(data_iterator, model: MambaModel):
def is_dataset_built_on_rank():
- return (
- mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage()
- ) and mpu.get_tensor_model_parallel_rank() == 0
+ return mpu.get_tensor_model_parallel_rank() == 0
def core_gpt_dataset_config_from_args(args):