From 5c4cedd78ba54f9cddc61311019bf8c99f243a6e Mon Sep 17 00:00:00 2001
From: y30062407
Date: Wed, 30 Jul 2025 11:14:28 +0800
Subject: [PATCH] [mindspore][bugfix][master]fix precision align

---
 .../core/models/common/embeddings/rotary_pos_embedding.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/mindspeed_llm/mindspore/core/models/common/embeddings/rotary_pos_embedding.py b/mindspeed_llm/mindspore/core/models/common/embeddings/rotary_pos_embedding.py
index 3fde25c158..e5a0b26531 100644
--- a/mindspeed_llm/mindspore/core/models/common/embeddings/rotary_pos_embedding.py
+++ b/mindspeed_llm/mindspore/core/models/common/embeddings/rotary_pos_embedding.py
@@ -85,10 +85,6 @@ def apply_rotary_pos_emb_bshd_func(
     rot_dim = freqs.shape[-1]
     t, t_pass = t[..., :rot_dim], t[..., rot_dim:]
 
-    if multi_latent_attention:
-        x1 = t[..., 0::2]
-        x2 = t[..., 1::2]
-        t = torch.cat((x1, x2), dim=-1)
 
     cos_ = (torch.cos(freqs) * _mscale).to(t.dtype)
     sin_ = (torch.sin(freqs) * _mscale).to(t.dtype)
@@ -97,6 +93,10 @@ def apply_rotary_pos_emb_bshd_func(
         mode = 1 if rotary_interleaved else 0
         t = torch_npu.npu_rotary_position_embedding(t.contiguous(), cos_, sin_, mode).to(t.dtype)
     else:
+        if multi_latent_attention:
+            x1 = t[..., 0::2]
+            x2 = t[..., 1::2]
+            t = torch.cat((x1, x2), dim=-1)
         t = (t * cos_) + (_rotate_half(t, rotary_interleaved) * sin_)
 
     return torch.cat((t, t_pass), dim=-1)
-- 
Gitee
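
Reviewer note (not part of the patch): the change moves the multi-latent-attention
(MLA) even/odd channel regrouping out of the common path and into the eager
fallback branch, so the fused torch_npu.npu_rotary_position_embedding kernel now
receives the tensor in its original layout; applying the regrouping before the
fused kernel is presumably what broke precision alignment. The snippet below is a
minimal, self-contained sketch of the post-patch control flow for illustration
only: the function name apply_rotary_pos_emb_bshd_sketch, the
use_fused_rotary_pos_emb flag, and the simplified _rotate_half helper are
stand-ins for the real module's code, not its actual API.

    import torch

    try:
        import torch_npu  # Ascend-only extension used by the fused path in the real module
    except ImportError:
        torch_npu = None


    def _rotate_half(x, rotary_interleaved):
        # Simplified stand-in for the module's _rotate_half helper.
        if rotary_interleaved:
            x1, x2 = x[..., 0::2], x[..., 1::2]
            return torch.stack((-x2, x1), dim=-1).reshape(x.shape)
        x1, x2 = torch.chunk(x, 2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)


    def apply_rotary_pos_emb_bshd_sketch(t, freqs, _mscale=1.0,
                                         rotary_interleaved=False,
                                         multi_latent_attention=False,
                                         use_fused_rotary_pos_emb=False):
        # Only the first rot_dim channels are rotated; the rest pass through.
        rot_dim = freqs.shape[-1]
        t, t_pass = t[..., :rot_dim], t[..., rot_dim:]

        cos_ = (torch.cos(freqs) * _mscale).to(t.dtype)
        sin_ = (torch.sin(freqs) * _mscale).to(t.dtype)

        if use_fused_rotary_pos_emb and torch_npu is not None:
            # Fused NPU kernel path: after the patch, t is passed through in its
            # original layout; the MLA regrouping is no longer applied here.
            mode = 1 if rotary_interleaved else 0
            t = torch_npu.npu_rotary_position_embedding(
                t.contiguous(), cos_, sin_, mode).to(t.dtype)
        else:
            # Eager fallback: MLA stores paired channels interleaved, so regroup
            # even and odd channels into two contiguous halves before rotating.
            if multi_latent_attention:
                x1, x2 = t[..., 0::2], t[..., 1::2]
                t = torch.cat((x1, x2), dim=-1)
            t = (t * cos_) + (_rotate_half(t, rotary_interleaved) * sin_)

        return torch.cat((t, t_pass), dim=-1)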