From 213bafb8676f34267dc2e8af881d9743aefb0c0f Mon Sep 17 00:00:00 2001
From: kingsleynadher
Date: Thu, 26 Jun 2025 14:14:09 +0800
Subject: [PATCH] [pytorch][feature] pytorch model mamba cp algorithm support

---
 .../features/mamba_context_parallel.md        |  53 ++
 mindspeed_llm/core/ssm/mamba_mixer.py         |  64 ++-
 .../ssm/state_space_context_parallel.py       | 536 ++++++++++++++++++
 .../tasks/models/ssm/state_space_duality.py   | 111 +++-
 mindspeed_llm/training/arguments.py           |   4 +-
 mindspeed_llm/training/utils.py               |   2 +-
 tests/pipeline/mamba2/mamba2_2.7b_tp1_pp1.sh  |   4 +
 7 files changed, 743 insertions(+), 31 deletions(-)
 create mode 100644 docs/pytorch/features/mamba_context_parallel.md
 create mode 100644 mindspeed_llm/tasks/models/ssm/state_space_context_parallel.py

diff --git a/docs/pytorch/features/mamba_context_parallel.md b/docs/pytorch/features/mamba_context_parallel.md
new file mode 100644
index 000000000..3d2fa4a5d
--- /dev/null
+++ b/docs/pytorch/features/mamba_context_parallel.md
@@ -0,0 +1,53 @@
+# Mamba-CP
+
+## Background
+
+Mamba was proposed to remove the quadratic sequence-length complexity of Transformer models and has become an important architecture for long-sequence training. As sequence length grows, activation memory pressure rises sharply, so context parallelism (CP) is still needed to cut the memory cost of ultra-long sequences; open-source Mamba frameworks currently offer no CP support.
+
+## Problem
+
+The SSM recurrence in Mamba carries a temporal dependency: with conventional CP, each CP rank must wait for the previous rank to finish and hand over its result before it can run the next step, which introduces idle time. We designed a parallel Mamba-CP scheme in which all ranks execute the state-passing computation concurrently, giving a large speedup over conventional CP.
+
+For conventional CP, see Figure 5 of the Mamba-2 paper.
+
+## Solution
+
+For the state-passing step that carries the temporal dependency, the local_decay and local_state of every CP rank are AllGathered so that all CP ranks can perform the state-passing computation concurrently. In addition, the forward AllGather and the backward ReduceScatter are overlapped with computation.
+
+## Applicable scenarios
+
+1. Orthogonal to TP-SP: CP can be enabled on top of TP to further reduce memory.
+2. TP is limited by the n_groups divisibility constraint; CP has no such restriction.
+3. When memory is insufficient, enabling CP performs better than saving memory through recomputation.
+
+
+## Usage
+
+| Key argument | Description |
+|---------------------------------------|-----------------------------------------------------------------|
+| --context-parallel-algo mamba_cp_algo | Context-parallel algorithm. Defaults to `ulysses_cp_algo`; set to `mamba_cp_algo` to enable Mamba-CP. |
+| --context-parallel-size [int] | Context-parallel degree. Defaults to 1; configure as required. |
+
+
+## Results
+
+Memory can be saved by recomputation, by enabling additional CP, and so on. Recomputation is the common approach but adds roughly 30% runtime. In long-sequence, memory-constrained scenarios the goal is to combine features so that performance is maximized without exceeding device memory. At equal memory usage, saving memory via CP and via recomputation compare as follows:
+
+Memory saving and performance change after enabling CP:
+
+| Sequence length | Parallel config | Memory usage | Memory saving | Iteration time | Perf change |
+|------|--------|---------|------|----------|--------|
+| 32K | TP4CP1 | 56129MB | —— | 3761.1ms | —— |
+| 32K | TP4CP2 | 32613MB | 42% | 3862.3ms | -2.69% |
+
+CP vs. recomputation at comparable memory usage:
+
+| Sequence length | Parallel config | Memory usage | Iteration time | Speedup |
+|------|----------------|----------|-----------|-----------|
+| 32K | TP4CP1 + full recompute | ~30G (matched) | 4728.8ms | —— |
+| 32K | TP4CP2 | ~30G (matched) | 3862.3ms | +22.43% |
+
+
+## Notes
+
+1. When memory must be saved in a Mamba-CP scenario, enable CP first and only then add recomputation.
\ No newline at end of file
diff --git a/mindspeed_llm/core/ssm/mamba_mixer.py b/mindspeed_llm/core/ssm/mamba_mixer.py
index 89b4b01b2..964c2bd5c 100644
--- a/mindspeed_llm/core/ssm/mamba_mixer.py
+++ b/mindspeed_llm/core/ssm/mamba_mixer.py
@@ -1,3 +1,4 @@
+# Copyright (c) 2025, HUAWEI CORPORATION. All rights reserved.
import math from copy import deepcopy from dataclasses import dataclass @@ -9,7 +10,10 @@ import torch.nn as nn import torch.nn.functional as F import torch_npu from megatron.training import get_args +from megatron.core import mpu + from mindspeed_llm.tasks.models.ssm.state_space_duality import StateSpaceProcessor, ProcessInputs, StateOptions +from mindspeed_llm.tasks.models.ssm.state_space_context_parallel import SequenceParallelConvFunction def mamba_mixer_init_wrapper(fn): @@ -38,6 +42,10 @@ def mamba_mixer_forward(self, hidden_states, seqlen=None, seq_idx=None, cu_seqle hidden_states: (nL, B, D) / (L B D) Returns: same shape as hidden_states """ + cp_size = mpu.get_context_parallel_world_size() + cp_rank = mpu.get_context_parallel_rank() + cp_group = mpu.get_context_parallel_group() + seqlen_og = seqlen if seqlen is None: seqlen, batch, dim = hidden_states.shape @@ -77,27 +85,43 @@ def mamba_mixer_forward(self, hidden_states, seqlen=None, seq_idx=None, cu_seqle ], dim=-1, ) - - # transpose: b l pd --> b pd l - xBC = rearrange(xBC, "b l d -> b d l").contiguous() - - # Compute short convolution - if conv_state is not None: - if cu_seqlens: + if cp_size > 1: + xBC, dt = SequenceParallelConvFunction.apply( + xBC, # Input xBC for current rank + dt, # Input dt for current rank + self.conv1d.weight, + self.conv1d.bias, + self.dt_bias, + cp_group, + cp_size, + cp_rank, + self.d_conv, # kernel_size + self.nheads_local, + self.d_inner_local, # For splitting xBC_processed later + self.d_state, # For splitting xBC_processed later + self.ngroups_local # For splitting xBC_processed later + ) + else: + # transpose: b l pd --> b pd l + xBC = rearrange(xBC, "b l d -> b d l").contiguous() + + # Compute short convolution + if conv_state is not None: + if cu_seqlens: + raise('Variable length inputs in convolution are not currently supported') + # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv + # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. + conv_state.copy_( + F.pad(xBC, (self.d_conv - xBC.shape[-1], 0)) + ) # Update state (B D W) + + seqlen = xBC.size(2) + if seq_idx: raise('Variable length inputs in convolution are not currently supported') - # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv - # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. 
- conv_state.copy_( - F.pad(xBC, (self.d_conv - xBC.shape[-1], 0)) - ) # Update state (B D W) - - seqlen = xBC.size(2) - if seq_idx: - raise('Variable length inputs in convolution are not currently supported') - xBC = self.act(self.conv1d(xBC)[..., :seqlen]) - - # transpose b pd l --> b l pd - xBC = rearrange(xBC, "b d l -> b l d").contiguous() + xBC = self.act(self.conv1d(xBC)[..., :seqlen]) + + # transpose b pd l --> b l pd + xBC = rearrange(xBC, "b d l -> b l d").contiguous() x, B, C = torch.split( xBC, diff --git a/mindspeed_llm/tasks/models/ssm/state_space_context_parallel.py b/mindspeed_llm/tasks/models/ssm/state_space_context_parallel.py new file mode 100644 index 000000000..022389cac --- /dev/null +++ b/mindspeed_llm/tasks/models/ssm/state_space_context_parallel.py @@ -0,0 +1,536 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.distributed as dist + + +def allgather_async(tensor: torch.Tensor, + group: dist.ProcessGroup, + size: int): + """ + Gather `tensor` from every rank in `group` asynchronously, returns + work handler and buffer with `tensor` from each rank + """ + if size == 1: + return None, tensor + + tensor = tensor.contiguous() + gather_buf = torch.empty( + (size, *tensor.shape), dtype=tensor.dtype, device=tensor.device + ) + + work = dist.all_gather_into_tensor( + gather_buf, tensor, group=group, async_op=True + ) + return work, gather_buf + + +class SequenceParallelConvFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, + xBC_curr_rank, + dt_curr_rank, + conv1d_weight, + conv1d_bias, + dt_bias, + cp_group, + cp_size, + cp_rank, + kernel_sz, + nheads, + d_inner, + d_state, + ngroups): + """ + Implements the forward pass for a sequence-parallel 1D convolution. + + This function handles the communication required for causal convolutions + across different ranks in a context-parallel group. It overlaps the + asynchronous communication (`all_gather` of convolution "tails") with + independent computations (processing of the dt tensor) to improve + performance. + + Args: + ctx: The context object for `torch.autograd.Function` to save tensors for backward. + xBC_curr_rank (Tensor): The local shard of the main input tensor for the current rank. + Shape: [B, L_local, D_xBC], where B is batch size, L_local is the + local sequence length, and D_xBC is the feature dimension for + the combined x, B, and C tensors of the SSM. + dt_curr_rank (Tensor): The local shard of the timestep tensor (delta t, or Δt). + Shape: [B, L_local, nheads]. + conv1d_weight (Tensor): The weight parameter of the nn.Conv1d layer. + Shape: [channels, 1, kernel_size]. + conv1d_bias (Tensor): The bias parameter of the nn.Conv1d layer. Can be None. + Shape: [channels]. + dt_bias (Tensor): A learnable bias parameter added to the dt tensor. + Shape: [nheads]. + cp_group (ProcessGroup): The process group for context-parallel communication. + cp_size (int): The world size of the context-parallel group. + cp_rank (int): The rank of the current process in the context-parallel group. + kernel_sz (int): The kernel size of the 1D convolution (d_conv). + nheads (int): Number of attention heads. + d_inner (int): The inner dimension of the Mamba block. + d_state (int): The state dimension (N) of the SSM. + ngroups (int): Number of groups for B and C parameters. + + Returns: + xBC_processed (Tensor): The output tensor for the current rank after the sequence-parallel + convolution and SiLU activation. Shape: [B, L_local, D_xBC]. 
+ dt_processed (Tensor): The processed dt tensor after adding the bias and applying softplus. + Shape: [B, L_local, nheads]. + """ + + + tail_len = kernel_sz - 1 + + # 1. Prepare local_tail for all_gather + local_tail_for_ag = None # Initialize + if cp_size > 1 and tail_len > 0: + local_tail_for_ag = xBC_curr_rank[:, -tail_len:, :].contiguous() + elif cp_size > 1 and tail_len == 0: # Should not happen with kernel_sz > 1 + # Create an empty tensor with correct batch and feature dims, but 0 seq_len + local_tail_for_ag = torch.empty((xBC_curr_rank.shape[0], 0, xBC_curr_rank.shape[2]), + dtype=xBC_curr_rank.dtype, device=xBC_curr_rank.device) + + # 2. Initiate Asynchronous AllGather for local_tail_for_ag + ag_work_handle, ag_buf = None, None + if tail_len > 0: # Only do all_gather if there's actually a tail + ag_work_handle, ag_buf = allgather_async(local_tail_for_ag, cp_group, cp_size) + elif cp_size > 1: # tail_len is 0 but cp_size > 1 + ag_work_handle, ag_buf = allgather_async(local_tail_for_ag, cp_group, cp_size) + + # 3. Perform Computation A (independent of all_gather result - dt processing) + dt_contiguous = dt_curr_rank.contiguous() + dt_plus_bias = dt_contiguous + dt_bias + dt_processed = F.softplus(dt_plus_bias) + + # 4. Wait for all_gather to complete (if it was launched) + if ag_work_handle: + ag_work_handle.wait() + + # 5. Prepare input for convolution using gathered tails + prev_tail_data = None + if tail_len > 0: + if cp_size == 1: + prev_tail_data = torch.zeros_like(local_tail_for_ag) # Padded with zeros + elif cp_rank == 0: + prev_tail_data = torch.zeros_like(local_tail_for_ag) # Padded with zeros + else: + prev_tail_data = ag_buf[cp_rank - 1] + conv_input = torch.cat([prev_tail_data, xBC_curr_rank], dim=1) + else: + conv_input = xBC_curr_rank + + # 6. Perform Convolution + conv_input_transposed = conv_input.transpose(1, 2).contiguous() + padding_val = 0 + conv_output_transposed = F.conv1d( + conv_input_transposed, + conv1d_weight, + conv1d_bias, + stride=1, + padding=padding_val, + dilation=1, + groups=conv1d_weight.shape[0] + ) + conv_output_full = conv_output_transposed.transpose(1, 2) + xBC_conv_sliced = conv_output_full + xBC_processed = F.silu(xBC_conv_sliced.contiguous()) + + # Save tensors for backward + ctx.save_for_backward( + xBC_curr_rank, dt_contiguous, local_tail_for_ag, ag_buf, prev_tail_data, conv_input, + conv1d_weight, conv1d_bias, dt_bias, dt_plus_bias, xBC_conv_sliced + ) + # Save non-tensor attributes + ctx.cp_group = cp_group + ctx.cp_size = cp_size + ctx.cp_rank = cp_rank + ctx.kernel_sz = kernel_sz + ctx.tail_len = tail_len + ctx.padding_val = padding_val + ctx.conv_groups = conv1d_weight.shape[0] + ctx.xBC_curr_rank_seqlen = xBC_curr_rank.shape[1] + ctx.xBC_curr_rank_requires_grad = xBC_curr_rank.requires_grad + ctx.conv1d_weight_requires_grad = conv1d_weight.requires_grad + ctx.conv1d_bias_requires_grad = conv1d_bias is not None and conv1d_bias.requires_grad + + ctx.dt_curr_rank_requires_grad = dt_curr_rank.requires_grad + ctx.dt_bias_requires_grad = dt_bias.requires_grad + + return xBC_processed, dt_processed + + @staticmethod + def backward(ctx, grad_xBC_processed, grad_dt_processed): + ( + xBC_curr_rank, dt_contiguous_saved, local_tail_saved, ag_buf_saved, prev_tail_saved, conv_input_saved, + conv1d_weight_saved, conv1d_bias_saved, dt_bias_saved, dt_plus_bias_saved, xBC_conv_sliced_saved + ) = ctx.saved_tensors + + # Backward for Computation B (Convolution and SiLU path) + + # 1. 
SiLU backward via official operator for perfect alignment + x_slice = xBC_conv_sliced_saved.contiguous() + grad_silu = torch.ops.aten.silu_backward( + grad_xBC_processed.contiguous(), + x_slice + )[0] + + # 2. Reconstruct full conv1d output gradient + B, L_in, C = conv_input_saved.shape + L_out = L_in - ctx.tail_len + grad_conv_out = conv_input_saved.new_zeros((B, L_out, C)) + grad_conv_out[:, :, :] = grad_silu + + # 3. Convert to (B, C, L_out) for conv1d grad + grad_out_tc = grad_conv_out.transpose(1, 2).contiguous() # (B, C_out, L_out) + + # 4. Compute grad w.r.t. conv input & weight via high-level API + inp_tc = conv_input_saved.transpose(1, 2).contiguous() # (B, C_in, L_in) + grad_conv_input_tc = torch.nn.grad.conv1d_input( + input_size=inp_tc.shape, + weight=conv1d_weight_saved, + grad_output=grad_out_tc, + stride=1, + padding=ctx.padding_val, + dilation=1, + groups=ctx.conv_groups, + ) + grad_conv1d_weight_val = torch.nn.grad.conv1d_weight( + input=inp_tc, + weight_size=conv1d_weight_saved.shape, + grad_output=grad_out_tc, + stride=1, + padding=ctx.padding_val, + dilation=1, + groups=ctx.conv_groups, + ) + grad_conv1d_bias_val = None + if conv1d_bias_saved is not None and ctx.conv1d_bias_requires_grad: + grad_conv1d_bias_val = grad_out_tc.sum(dim=(0, 2)) # sum over batch & length + + # 5. Convert grad_conv_input back to (B, L_in, C) + grad_conv_input = grad_conv_input_tc.transpose(1, 2) + + # 6. Split into prefix tail and main body + if ctx.tail_len > 0: + grad_prev_tail = grad_conv_input[:, :ctx.tail_len, :] + grad_xBC_from_conv = grad_conv_input[:, ctx.tail_len:, :] + else: + grad_prev_tail = torch.empty((B, 0, C), device=grad_conv_input.device, dtype=grad_conv_input.dtype) + grad_xBC_from_conv = grad_conv_input + + # Gradients for AllGather input (local_tail_saved) via ReduceScatter + grad_local_tail_scattered = None + rs_handle = None + if ctx.cp_size > 1 and ctx.tail_len > 0 and ctx.xBC_curr_rank_requires_grad: + # prepare grad_buf of shape (cp_size, B, tail_len, C) + grad_buf = torch.zeros_like(ag_buf_saved, dtype=grad_prev_tail.dtype) + if ctx.cp_rank > 0: + grad_buf[ctx.cp_rank - 1] = grad_prev_tail + + grad_local_tail_scattered = torch.empty_like(local_tail_saved) + if grad_local_tail_scattered.numel() > 0: + rs_handle = dist.reduce_scatter_tensor( + output=grad_local_tail_scattered, + input=grad_buf, + op=dist.ReduceOp.SUM, + group=ctx.cp_group, + async_op=True + ) + + # Gradients for Computation A (dt_processed path) + grad_dt_plus_bias = grad_dt_processed * torch.sigmoid(dt_plus_bias_saved) + grad_dt_curr = grad_dt_plus_bias + grad_dt_bias = None + if ctx.dt_bias_requires_grad: + dims = list(range(grad_dt_plus_bias.ndim - dt_bias_saved.ndim)) + summed = grad_dt_plus_bias.sum(dim=dims) + grad_dt_bias = summed.reshape(dt_bias_saved.shape) + + # 7. Wait for scatter to finish + if rs_handle is not None: + rs_handle.wait() + + # 8. Combine xBC gradients + grad_xBC_total = None + if ctx.xBC_curr_rank_requires_grad: + grad_xBC_total = torch.zeros_like(xBC_curr_rank) + grad_xBC_total += grad_xBC_from_conv + if grad_local_tail_scattered is not None and grad_local_tail_scattered.numel() > 0: + actual = min(ctx.tail_len, xBC_curr_rank.shape[1]) + if actual > 0: + part = grad_local_tail_scattered[:, -actual:, :] + grad_xBC_total[:, -actual:, :] += part + + # 9. 
Zero out grads for non-requires + final_dt = grad_dt_curr + final_w = grad_conv1d_weight_val + final_b = grad_conv1d_bias_val + final_db = grad_dt_bias + + return ( + grad_xBC_total, + final_dt, + final_w, + final_b, + final_db, + None, None, None, None, None, None, None, None + ) + + +class SSDAllgatherOverlapFn(torch.autograd.Function): + @staticmethod + def forward(ctx, + local_decay, + local_hidden_state, + C_ssd_chunked_b, + B_ssd_chunked_b, + x_ssd_chunked, + A_ssd_reshaped, + A_cumsum_ssd, + cp_group, + cp_size, + segsum_fn_ref, + device_for_ops): + """ + Implements the forward pass for the core sequence-parallel SSM computation. + + This function orchestrates the communication-computation overlap central to the + sequence-parallel Structured State Space (SSD) algorithm. It initiates an + asynchronous all-gather of the per-rank state contributions (local_decay and + local_hidden_state) across the context-parallel group. While this + communication is in flight, it computes the independent, intra-chunk (diagonal) + part of the SSM output, Y_diag. + + Args: + ctx: The context object for `torch.autograd.Function` to save tensors for backward. + local_decay (Tensor): The per-rank decay factor (Λ_r). It represents how much the + hidden state decays over the entire sequence segment processed + by the current rank. Shape: [B, H]. + local_hidden_state (Tensor): The per-rank hidden state contribution (H_r). It is the state + accumulated over the current rank's sequence, assuming a + zero input state. Shape: [B, H, P, N]. + C_ssd_chunked_b (Tensor): The pre-transformed (permuted and reshaped) C parameter tensor + for the current rank's chunks, optimized for BMM. + Shape: [B*H*C_chunks, L_chunk, N]. + B_ssd_chunked_b (Tensor): The pre-transformed (permuted, reshaped, and transposed) B + parameter tensor for the current rank's chunks. + Shape: [B*H*C_chunks, N, L_chunk]. + x_ssd_chunked (Tensor): The chunked and discretized input tensor x for the current rank. + Shape: [B, C_chunks, L_chunk, H, P]. + A_ssd_reshaped (Tensor): The chunked and reshaped Δt * A tensor, used to compute the + L_ij matrix for the Y_diag calculation. + Shape: [B, H, C_chunks, L_chunk]. + A_cumsum_ssd (Tensor): The cumulative sum of A_ssd_reshaped. Used to compute the + intra-chunk decay for the Y_off calculation. + Shape: [B, H, C_chunks, L_chunk]. + cp_group (ProcessGroup): The process group for context-parallel communication. + cp_size (int): The world size of the context-parallel group. + segsum_fn_ref (function): A reference to the function for computing the segmented sum + (used to create the L_ij matrix). + device_for_ops (torch.device): The target device for creating new tensors if needed. + + Returns: + ld_buf (Tensor): The buffer containing `local_decay` gathered from all ranks. + Shape: [cp_size, B, H]. + lhs_buf (Tensor): The buffer containing `local_hidden_state` gathered from all ranks. + Shape: [cp_size, B, H, P, N]. + Y_diag_computed (Tensor): The diagonal component (Y_diag) of the SSM output for the + current rank. Shape: [B, C_chunks, L_chunk, H, P]. + state_decay_out_computed (Tensor): The intra-chunk state decay factor (exp(A_cumsum)), + used for the Y_off calculation. + Shape: [B, H, C_chunks, L_chunk]. + """ + + # 1. Initiate Asynchronous AllGathers + work_ld, ld_buf = allgather_async(local_decay, cp_group, cp_size) + work_lhs, lhs_buf = allgather_async(local_hidden_state, cp_group, cp_size) + + # 2. 
Perform "Independent" Forward Computations + s_val_for_L = segsum_fn_ref(A_ssd_reshaped) + L_val = torch.exp(s_val_for_L) + + # Start of Y_diag_computed calculation using bmm approach + B_dim = x_ssd_chunked.shape[0] + C_chunks_dim = x_ssd_chunked.shape[1] + L_chunk_dim = x_ssd_chunked.shape[2] + H_heads_dim = x_ssd_chunked.shape[3] + N_state_dim = B_ssd_chunked_b.shape[1] + P_headdim_dim = x_ssd_chunked.shape[4] + + # 1. Permute x_ssd_chunked to align with ssdOrigin's C_r, B_r, x_r + x_r_like = x_ssd_chunked.permute(0, 3, 1, 2, 4).contiguous() # (B, H_heads, C_chunks, L_chunk, P_headdim) + + # 2. Reshape for batch matrix multiplication (bmm) + x_b_like = x_r_like.reshape(-1, L_chunk_dim, P_headdim_dim) + L_b_like = L_val.to(torch.bfloat16).reshape(-1, L_val.shape[3], L_val.shape[4]) + + # 3. Perform bmm operations as in ssdOrigin + CB_b_like = torch.bmm(C_ssd_chunked_b, B_ssd_chunked_b) + CBL_b_like = (CB_b_like * L_b_like).to(torch.bfloat16) + Y_diag_intermediate = torch.bmm(CBL_b_like, x_b_like) + + # 4. Reshape and permute Y_diag_intermediate to the final target shape + Y_diag_computed_temp = Y_diag_intermediate.reshape( + B_dim, H_heads_dim, C_chunks_dim, L_chunk_dim, P_headdim_dim + ) + # Permute to target shape (B, C_chunks, L_chunk, H_heads, P_headdim) + Y_diag_computed = Y_diag_computed_temp.permute(0, 2, 3, 1, 4).contiguous() + + # End of Y_diag_computed calculation using bmm approach + + state_decay_out_computed = torch.exp(A_cumsum_ssd).to(torch.bfloat16).contiguous() + + # 3. Wait for AllGathers + if work_ld: + work_ld.wait() + if work_lhs: + work_lhs.wait() + + # Save tensors for backward pass + ctx.save_for_backward(local_decay, local_hidden_state, ld_buf, lhs_buf, + C_ssd_chunked_b, B_ssd_chunked_b, x_ssd_chunked, + A_ssd_reshaped, A_cumsum_ssd, + L_val, + s_val_for_L, + state_decay_out_computed + ) + # Save attributes + ctx.cp_group = cp_group + ctx.cp_size = cp_size + + # For manual segsum backward + ctx.T_for_segsum = A_ssd_reshaped.size(-1) + ctx.device_for_ops = device_for_ops + + return ld_buf, lhs_buf, Y_diag_computed, state_decay_out_computed + + + @staticmethod + def backward(ctx, + grad_ld_buf, grad_lhs_buf, + grad_Y_diag_computed, grad_state_decay_out_computed): + + (local_decay_saved, local_hidden_state_saved, ld_buf_saved, lhs_buf_saved, + C_b_s, B_b_s, x_ssd_chunked_saved, + A_ssd_reshaped_saved, A_cumsum_ssd_saved, + L_val_saved, s_val_for_L_saved, state_decay_out_saved + ) = ctx.saved_tensors + + # 1. Initiate Asynchronous ReduceScatters + grad_local_decay_rs_output = None + rs_handle_ld = None + if ctx.cp_size > 1: + if local_decay_saved.numel() > 0: + grad_local_decay_rs_output = torch.empty_like(local_decay_saved) + rs_handle_ld = dist.reduce_scatter_tensor(grad_local_decay_rs_output, grad_ld_buf, group=ctx.cp_group, async_op=True) + else: + grad_local_decay_rs_output = torch.empty_like(local_decay_saved) + else: + grad_local_decay_rs_output = grad_ld_buf.clone() if grad_ld_buf is not None else None + + grad_local_hidden_state_rs_output = None + rs_handle_lhs = None + if ctx.cp_size > 1: + if local_hidden_state_saved.numel() > 0: + grad_local_hidden_state_rs_output = torch.empty_like(local_hidden_state_saved) + rs_handle_lhs = dist.reduce_scatter_tensor(grad_local_hidden_state_rs_output, grad_lhs_buf, group=ctx.cp_group, async_op=True) + else: + grad_local_hidden_state_rs_output = torch.empty_like(local_hidden_state_saved) + else: + grad_local_hidden_state_rs_output = grad_lhs_buf.clone() if grad_lhs_buf is not None else None + + # 2. 
DIRECT Gradient Computations (Overlapped) + grad_C_ssd_chunked_b = None + grad_B_ssd_chunked_b = None + grad_x_ssd_chunked = None + grad_A_cumsum_ssd = None + grad_L_val_intermediate = None # d(Loss)/d(L_val) + + # Gradients for Y_diag_computed inputs (using bmm decomposition) + if grad_Y_diag_computed is not None: + gY_out = grad_Y_diag_computed # Shape: (B, C_chunks, L_chunk, H_heads, P_headdim) + + X_in_s = x_ssd_chunked_saved + L_in_s = L_val_saved + + # Get dimensions from saved tensors + B_dim = X_in_s.shape[0] + C_chunks_dim = X_in_s.shape[1] + L_chunk_dim = X_in_s.shape[2] + H_heads_dim = X_in_s.shape[3] + P_headdim_dim = X_in_s.shape[4] + B_eff = C_b_s.shape[0] # B * H * C_chunks + + # Inverse permutation (0,3,1,2,4) maps (B,C,L,H,P) back to (B,H,C,L,P) + gY_temp_reshaped = gY_out.permute(0, 3, 1, 2, 4) + + # Backward through Y_temp_reshaped = Y_inter_b.reshape(...) + gY_inter_b = gY_temp_reshaped.reshape(B_eff, L_chunk_dim, P_headdim_dim) + + # Recompute intermediates from forward pass needed for bmm backward + X_r_like_s = X_in_s.permute(0, 3, 1, 2, 4) + X_b_s = X_r_like_s.reshape(B_eff, L_chunk_dim, P_headdim_dim) + + L_b_s = L_in_s.to(torch.bfloat16).reshape(B_eff, L_chunk_dim, L_chunk_dim) + + CB_b_s = torch.bmm(C_b_s, B_b_s) # (B_eff, L_chunk, L_chunk) + CBL_b_s = (CB_b_s * L_b_s).to(torch.bfloat16) # (B_eff, L_chunk, L_chunk) + + # Backward through Y_inter_b = torch.bmm(CBL_b_s, X_b_s) + gCBL_b = torch.bmm(gY_inter_b, X_b_s.transpose(1, 2)) + gX_b = torch.bmm(CBL_b_s.transpose(1, 2), gY_inter_b).contiguous() + gX_r_like = gX_b.reshape(B_dim, H_heads_dim, C_chunks_dim, L_chunk_dim, P_headdim_dim) + grad_x_ssd_chunked = gX_r_like.permute(0, 2, 3, 1, 4).contiguous() + + # Backward through CBL_b_s = CB_b_s * L_b_s (element-wise) + gCB_b_from_CBL = None + if gCBL_b is not None: + gCB_b_from_CBL = (gCBL_b * L_b_s).to(torch.bfloat16) + gL_b = (gCBL_b * CB_b_s).to(torch.bfloat16).contiguous() + grad_L_val_intermediate = gL_b.reshape(B_dim, H_heads_dim, C_chunks_dim, L_chunk_dim, L_chunk_dim) + + # Backward through CB_b_s = torch.bmm(C_b_s, B_b_s) + if gCB_b_from_CBL is not None: + grad_C_ssd_chunked_b = torch.bmm(gCB_b_from_CBL, B_b_s.transpose(1, 2)) + grad_B_ssd_chunked_b = torch.bmm(C_b_s.transpose(1, 2), gCB_b_from_CBL) + + # Gradient for A_ssd_reshaped (from L_val's gradient) + grad_A_ssd_reshaped = None + if grad_L_val_intermediate is not None: + grad_s_val = (grad_L_val_intermediate * L_val_saved) + T_mask_dim = A_ssd_reshaped_saved.size(-1) + if s_val_for_L_saved.shape[-1] != T_mask_dim or s_val_for_L_saved.shape[-2] != T_mask_dim: + raise ValueError(f"T_mask_dim ({T_mask_dim}) mismatch with s_val_for_L_saved dimensions " + f"({s_val_for_L_saved.shape[-2]}, {s_val_for_L_saved.shape[-1]})") + mask_triu = torch.tril(torch.ones(T_mask_dim, T_mask_dim, dtype=torch.bool, device=ctx.device_for_ops), diagonal=0) + mask_triu_bc = mask_triu.view((1,) * (grad_s_val.ndim - 2) + (T_mask_dim, T_mask_dim)) + grad_x_sc_inter = grad_s_val.masked_fill(~mask_triu_bc, 0.0) + grad_x_masked_tril_dim_minus_2 = torch.cumsum(torch.flip(grad_x_sc_inter, dims=[-2]), dim=-2) + grad_x_masked_tril = torch.flip(grad_x_masked_tril_dim_minus_2, dims=[-2]) + mask_tril = torch.tril(torch.ones(T_mask_dim, T_mask_dim, dtype=torch.bool, device=ctx.device_for_ops), diagonal=-1) + mask_tril_bc = mask_tril.view((1,) * (grad_x_masked_tril.ndim - 2) + (T_mask_dim, T_mask_dim)) + grad_x_rep = grad_x_masked_tril.masked_fill(~mask_tril_bc, 0.0) + grad_A_ssd_reshaped = grad_x_rep.sum(dim=-1) + + # Gradient for 
A_cumsum_ssd + if grad_state_decay_out_computed is not None: + grad_A_cumsum_ssd = (grad_state_decay_out_computed * state_decay_out_saved).to(torch.bfloat16) + + # 3. Wait for ReduceScatters + if rs_handle_ld: + rs_handle_ld.wait() + if rs_handle_lhs: + rs_handle_lhs.wait() + + return ( + grad_local_decay_rs_output, + grad_local_hidden_state_rs_output, + grad_C_ssd_chunked_b, + grad_B_ssd_chunked_b, + grad_x_ssd_chunked, + grad_A_ssd_reshaped, + grad_A_cumsum_ssd, + None, None, None, None + ) diff --git a/mindspeed_llm/tasks/models/ssm/state_space_duality.py b/mindspeed_llm/tasks/models/ssm/state_space_duality.py index 809eeffaf..33f26043a 100644 --- a/mindspeed_llm/tasks/models/ssm/state_space_duality.py +++ b/mindspeed_llm/tasks/models/ssm/state_space_duality.py @@ -3,6 +3,10 @@ from typing import Optional, Tuple import torch import torch.nn as nn from einops import rearrange, repeat +import torch.nn.functional as F +from megatron.core import mpu + +from mindspeed_llm.tasks.models.ssm.state_space_context_parallel import SSDAllgatherOverlapFn @dataclass @@ -67,6 +71,10 @@ class StateSpaceProcessor: Returns: y: (B, L, H, P) Output features """ + cp_size = mpu.get_context_parallel_world_size() + cp_rank = mpu.get_context_parallel_rank() + cp_group = mpu.get_context_parallel_group() + # Unpack inputs x, dt, A, B, C, D = inputs.x, inputs.dt, inputs.A, inputs.B, inputs.C, inputs.D @@ -78,18 +86,105 @@ class StateSpaceProcessor: # Dimension transformations x, dt, A, B, C = self._expand_dims(x, A, dt, B, C) B_exp, C_exp = self._expand_groups_to_heads(B, C) - dt_proc = self._process_time_step(dt) D = self._prepare_residual(D, x, pad_size) - # Chunk processing - x_pad, A_pad, B_pad, C_pad = self._chunk_and_pad(x, dt_proc, A, B_exp, C_exp, pad_size) + if cp_size == 1: + dt_proc = self._process_time_step(dt) + + # Chunk processing + x_pad, A_pad, B_pad, C_pad = self._chunk_and_pad(x, dt_proc, A, B_exp, C_exp, pad_size) + + # Core computations + Y_diag, states, A_cum, C_br = self._compute_diagonal_blocks(A_pad, B_pad, C_pad, x_pad) + Y_off, final_state = self._compute_inter_chunk_blocks(A_cum, C_br, states, initial_states) + + # Output synthesis + state_opts.final_state = final_state + + elif cp_size > 1: + self.config['dt_min'] = torch.tensor(self.config['dt_min'], dtype=torch.float32, device=dt.device) + self.config['dt_max'] = torch.tensor(self.config['dt_max'], dtype=torch.float32, device=dt.device) + dt_proc = torch.clamp(dt, self.config['dt_min'], self.config['dt_max']) + + x_c, A_c, B_c, C_c = self._chunk_and_pad(x, dt_proc, A, B_exp, C_exp, pad_size) + device = x_c.device + A_reshaped = rearrange(A_c, "b c l h -> b h c l").contiguous() # (B, H, C, L) + + A_cumsum = torch.cumsum(A_reshaped, dim=-1).contiguous() # (B, H, C, L) + + decay_states_arg = torch.exp(A_cumsum[:, :, :, -1:] - A_cumsum).to(torch.bfloat16) # (B,H,C,L) used with B_c, x_c + states_calc = torch.einsum("bclhn, bhcl, bclhp -> bchpn", B_c, decay_states_arg, x_c) + + if initial_states is not None: + initial_states_loop_val = initial_states + else: + initial_states_loop_val = torch.zeros_like(states_calc[:, :1], dtype=torch.float32) # (B, 1, H, P, N) + + padded_A_cumsum_end = F.pad(A_cumsum[:, :, :, -1], (1, 0)) + decay_chunk_calc = torch.exp(self._segmented_sum(padded_A_cumsum_end.contiguous())) # (B,H,C+1,C+1) + + local_decay = decay_chunk_calc[:, :, -1, 0].contiguous() # (B, H) + + new_states_calc = torch.einsum("bhzc, bchpn -> bzhpn", decay_chunk_calc[:, :, 1:, 1:].to(torch.bfloat16), states_calc).contiguous() + # Split into 
two parts: the first C_chunks-1 elements, and the last 1 element. + C_chunks = new_states_calc.shape[1] + new_states_prefix, local_hidden_state_unsqeezed = torch.split( + new_states_calc, + [C_chunks - 1, 1], + dim=1 + ) + # We squeeze it to match the original shape (B, H, P, N). + local_hidden_state = local_hidden_state_unsqeezed.squeeze(1).contiguous() + + L_chunk_dim_ssp = C_c.shape[2] + N_state_dim_ssp = C_c.shape[4] + C_r = C_c.permute(0, 3, 1, 2, 4) + C_b = C_r.reshape(-1, L_chunk_dim_ssp, N_state_dim_ssp) + + B_r = B_c.permute(0, 3, 1, 2, 4) # (B, H_heads, C_chunks, L_chunk, N_state) + B_b = B_r.reshape(-1, L_chunk_dim_ssp, N_state_dim_ssp).transpose(1, 2) # (B*H*C_chunks, N_state, L_chunk) + + ld_buf, lhs_buf, Y_diag, state_decay_out = SSDAllgatherOverlapFn.apply( + local_decay, local_hidden_state, + C_b, B_b, x_c, A_reshaped, A_cumsum, + cp_group, cp_size, + self._segmented_sum, device + ) - # Core computations - Y_diag, states, A_cum, C_br = self._compute_diagonal_blocks(A_pad, B_pad, C_pad, x_pad) - Y_off, final_state = self._compute_inter_chunk_blocks(A_cum, C_br, states, initial_states) + if cp_rank > 0: + prod_terms_for_i = [None] * cp_rank + prod_terms_for_i[cp_rank - 1] = torch.ones_like(ld_buf[0]) # ld_buf[0] is (B,H) + if cp_rank > 1: + decays_slice_S = ld_buf[1:cp_rank] + if decays_slice_S.numel() > 0: + # 1. Compute cumulative product on the flipped tensor, but do NOT flip it back. + cumprod_from_end = torch.cumprod(torch.flip(decays_slice_S, dims=[0]), dim=0) + # 2. Assign to the list using a reverse-indexed loop. + num_elements = cp_rank - 1 + for i in range(num_elements): + prod_terms_for_i[i] = cumprod_from_end[num_elements - 1 - i] + + for i in range(cp_rank): + current_decay_prod = prod_terms_for_i[i] + term_to_add = torch.einsum("bh, bhpn -> bhpn", current_decay_prod, lhs_buf[i]) + initial_states_loop_val += term_to_add.unsqueeze(1) # unsqueeze to (B,1,H,P,N) for broadcast + + added_init_state = torch.einsum("bhc, bihpn -> bchpn", + decay_chunk_calc[:, :, :-1, 0], # (B,H,C) where C is n_chunks + initial_states_loop_val # (B,1,H,P,N) broadcast i + ).to(torch.bfloat16).contiguous() + + zeros_for_cat = torch.zeros_like(new_states_prefix[:, :1]) # (B,1,H,P,N) + concatenated_new_states = torch.cat([zeros_for_cat, new_states_prefix], dim=1) # (B, C_chunks, H,P,N) + off_states = added_init_state + concatenated_new_states.contiguous() + + states_b = off_states.permute(0, 2, 1, 3, 4).reshape(-1, off_states.shape[3], off_states.shape[4]).transpose(-1, -2) + Cs_b = torch.bmm(C_b, states_b).reshape(C_r.shape[0], C_r.shape[1], C_r.shape[2], C_r.shape[3], states_b.shape[2]).contiguous() + state_decay_out_us = state_decay_out.unsqueeze(-1).contiguous() + Y_off = (Cs_b * state_decay_out_us).permute(0, 2, 3, 1, 4).contiguous() + + state_opts.final_state = None # Not handling inference state during training - # Output synthesis - state_opts.final_state = final_state return self._synthesize_output((Y_diag, Y_off, D), (pad_size, seq_len), state_opts) def _expand_dims(self, x, A, dt, B, C): diff --git a/mindspeed_llm/training/arguments.py b/mindspeed_llm/training/arguments.py index 61b88be8d..5914a94fd 100644 --- a/mindspeed_llm/training/arguments.py +++ b/mindspeed_llm/training/arguments.py @@ -208,7 +208,7 @@ def _add_cp_args(parser): group = parser.add_argument_group(title='cp parallel') group.add_argument('--context-parallel-algo', type=str, default='ulysses_cp_algo', choices=['ulysses_cp_algo', 'megatron_cp_algo', 'hybrid_cp_algo', 'adaptive_cp_algo', - 
'hybrid_adaptive_cp_algo'], help='context parallel algorithm') + 'hybrid_adaptive_cp_algo', 'mamba_cp_algo'], help='context parallel algorithm') group.add_argument('--ulysses-degree-in-cp', type=int, default=None) group.add_argument('--attention-mask-type', type=str, default='causal', choices=['causal', 'general'], help='context parallel attention mask type') @@ -289,7 +289,7 @@ def _validate_cp_args(args): if not args.use_mcore_models: raise AssertionError(f"Context parallel is only supported in Mcore.") - if args.context_parallel_algo == 'ulysses_cp_algo': + if args.context_parallel_algo == 'ulysses_cp_algo' or args.context_parallel_algo == 'mamba_cp_algo': assert args.seq_length % args.context_parallel_size == 0, f"sequence length must be divisible by context_parallel_size" _check_attention_head(args, args.context_parallel_size) diff --git a/mindspeed_llm/training/utils.py b/mindspeed_llm/training/utils.py index a5c8678fb..e990b8ff3 100644 --- a/mindspeed_llm/training/utils.py +++ b/mindspeed_llm/training/utils.py @@ -716,7 +716,7 @@ def get_batch_on_this_cp_rank(batch): batch = _get_batch_on_this_tp_y_cp_rank_in_megatron_cp(batch) else: batch = _get_batch_on_this_cp_rank_in_megatron_cp(batch) - elif args.context_parallel_algo == 'ulysses_cp_algo': + elif args.context_parallel_algo == 'ulysses_cp_algo' or args.context_parallel_algo == 'mamba_cp_algo': batch = _get_batch_on_this_cp_rank_in_ulysses_cp(batch) elif args.context_parallel_algo == 'hybrid_cp_algo': if args.cp_attention_mask_type == 'general': diff --git a/tests/pipeline/mamba2/mamba2_2.7b_tp1_pp1.sh b/tests/pipeline/mamba2/mamba2_2.7b_tp1_pp1.sh index 11d4f34a2..0440a9638 100644 --- a/tests/pipeline/mamba2/mamba2_2.7b_tp1_pp1.sh +++ b/tests/pipeline/mamba2/mamba2_2.7b_tp1_pp1.sh @@ -28,6 +28,8 @@ TOKENIZER_PATH="/data/mamba2-hf/mamba2_2.7b_from_8b.model" TP=1 PP=1 +CP=2 +CP_TYPE="mamba_cp_algo" NUM_LAYERS=64 SEQ_LEN=4096 MBS=1 @@ -42,6 +44,8 @@ MAMBA_ARGS=" --use-mcore-models \ --tensor-model-parallel-size ${TP} \ --pipeline-model-parallel-size ${PP} \ + --context-parallel-size ${CP} \ + --context-parallel-algo ${CP_TYPE} \ --sequence-parallel \ --num-layers ${NUM_LAYERS} \ --group-query-attention \ -- Gitee
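
Both SequenceParallelConvFunction and SSDAllgatherOverlapFn above follow the same communication/computation overlap pattern: launch an asynchronous AllGather of the per-rank state, run the computation that does not depend on the gathered result, and only then wait on the collective. The snippet below is a minimal standalone sketch of that pattern, assuming an already-initialized torch.distributed process group (e.g. launched via torchrun); the function and tensor names are illustrative placeholders, not part of the patch.

import torch
import torch.distributed as dist
import torch.nn.functional as F


def overlapped_allgather_sketch(local_state: torch.Tensor, group=None):
    """Gather `local_state` from all ranks while overlapping independent work."""
    world_size = dist.get_world_size(group)
    local_state = local_state.contiguous()
    gather_buf = torch.empty((world_size, *local_state.shape),
                             dtype=local_state.dtype, device=local_state.device)

    # 1. Launch the AllGather asynchronously (mirrors allgather_async in the patch).
    work = dist.all_gather_into_tensor(gather_buf, local_state, group=group, async_op=True)

    # 2. Compute something that does not depend on the gathered buffer
    #    (stands in for the dt / Y_diag computations overlapped in the patch).
    independent_result = F.softplus(local_state)

    # 3. Block only when the gathered per-rank states are actually needed.
    work.wait()
    return gather_buf, independent_result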