diff --git a/mindspeed_llm/mindspore/core/models/common/embeddings/rotary_pos_embedding.py b/mindspeed_llm/mindspore/core/models/common/embeddings/rotary_pos_embedding.py
index 3fde25c1589b23c78318faa6bed56c915aac9efb..e5a0b26531a7d3ff12397f937b9a620222102727 100644
--- a/mindspeed_llm/mindspore/core/models/common/embeddings/rotary_pos_embedding.py
+++ b/mindspeed_llm/mindspore/core/models/common/embeddings/rotary_pos_embedding.py
@@ -85,10 +85,6 @@ def apply_rotary_pos_emb_bshd_func(
     rot_dim = freqs.shape[-1]
     t, t_pass = t[..., :rot_dim], t[..., rot_dim:]
 
-    if multi_latent_attention:
-        x1 = t[..., 0::2]
-        x2 = t[..., 1::2]
-        t = torch.cat((x1, x2), dim=-1)
 
     cos_ = (torch.cos(freqs) * _mscale).to(t.dtype)
     sin_ = (torch.sin(freqs) * _mscale).to(t.dtype)
@@ -97,6 +93,10 @@ def apply_rotary_pos_emb_bshd_func(
         mode = 1 if rotary_interleaved else 0
         t = torch_npu.npu_rotary_position_embedding(t.contiguous(), cos_, sin_, mode).to(t.dtype)
     else:
+        if multi_latent_attention:
+            x1 = t[..., 0::2]
+            x2 = t[..., 1::2]
+            t = torch.cat((x1, x2), dim=-1)
         t = (t * cos_) + (_rotate_half(t, rotary_interleaved) * sin_)
 
     return torch.cat((t, t_pass), dim=-1)
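
For reference, here is a minimal standalone sketch (not part of the patch) of the channel regrouping that the `multi_latent_attention` branch performs, which this diff moves into the non-`torch_npu` fallback path. The helper name `mla_deinterleave` and the dummy tensor shape are illustrative; the slicing itself mirrors the moved lines.

```python
import torch


def mla_deinterleave(t: torch.Tensor) -> torch.Tensor:
    """Regroup interleaved rotary channels [x0, y0, x1, y1, ...] into the
    blocked layout [x0, x1, ..., y0, y1, ...], mirroring the
    multi_latent_attention branch in the diff above."""
    x1 = t[..., 0::2]  # even-indexed channels
    x2 = t[..., 1::2]  # odd-indexed channels
    return torch.cat((x1, x2), dim=-1)


# Illustrative usage on a dummy rotary slice (last dim = rot_dim).
t = torch.arange(8, dtype=torch.float32).reshape(1, 1, 1, 8)
print(mla_deinterleave(t))  # -> [[[[0., 2., 4., 6., 1., 3., 5., 7.]]]]
```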