diff --git a/mindspeed_llm/core/ssm/mamba_block.py b/mindspeed_llm/core/ssm/mamba_block.py
index c15023bde8ecc6a230bf06cf2cd87ac6886cbcf4..6ba74022036c3b536b03c11bec91c29664ceb4e9 100644
--- a/mindspeed_llm/core/ssm/mamba_block.py
+++ b/mindspeed_llm/core/ssm/mamba_block.py
@@ -86,6 +86,11 @@ def _mamba_block_method_checkpointed_forward_func(
                     inference_context=None,
                     rotary_pos_emb=rotary_pos_emb,
                 )
+                # The attention layer (currently a simplified transformer layer)
+                # outputs a tuple of (hidden_states, context). Context is intended
+                # for cross-attention, and is not needed in our model.
+                if isinstance(hidden_states, tuple):
+                    hidden_states = hidden_states[0]
             return hidden_states
 
         return custom_forward
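
For context, here is a minimal standalone sketch of why the guard matters. It is not the MindSpeed/Megatron code; `ToyAttentionLayer` and `run_layer` are hypothetical names used only for illustration, under the assumption that the attention layer returns a `(hidden_states, context)` tuple as described in the comment above.

```python
import torch


class ToyAttentionLayer(torch.nn.Module):
    """Hypothetical stand-in for the simplified transformer layer.

    It returns a (hidden_states, context) tuple; context is only meaningful
    for cross-attention and is None in a decoder-only setup.
    """

    def forward(self, hidden_states, attention_mask=None):
        context = None  # no cross-attention in this toy model
        return hidden_states, context


def run_layer(layer, hidden_states):
    hidden_states = layer(hidden_states)
    # Without this guard, the next (Mamba) layer would receive a tuple
    # instead of a tensor and fail on the first tensor operation.
    if isinstance(hidden_states, tuple):
        hidden_states = hidden_states[0]
    return hidden_states


if __name__ == "__main__":
    layer = ToyAttentionLayer()
    out = run_layer(layer, torch.randn(2, 4, 8))
    print(type(out), out.shape)  # <class 'torch.Tensor'> torch.Size([2, 4, 8])
```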