From 4e72213a304751172eb7025822e6c798ea8199f0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E8=B5=B5=E6=B1=9F=E6=B1=9F?=
Date: Mon, 26 May 2025 11:06:59 +0800
Subject: [PATCH] fix: modify the patch file
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../audio/SenseVoice/TorchAir/diff.patch | 50 +++++++++----------
 1 file changed, 23 insertions(+), 27 deletions(-)

diff --git a/ACL_PyTorch/built-in/audio/SenseVoice/TorchAir/diff.patch b/ACL_PyTorch/built-in/audio/SenseVoice/TorchAir/diff.patch
index ca2698dd0c..17eb818f7f 100755
--- a/ACL_PyTorch/built-in/audio/SenseVoice/TorchAir/diff.patch
+++ b/ACL_PyTorch/built-in/audio/SenseVoice/TorchAir/diff.patch
@@ -1,13 +1,3 @@
-From 008fbc757d824979c85cb65b5a5e8f0e3101f642 Mon Sep 17 00:00:00 2001
-From: shikang
-Date: Sat, 8 Mar 2025 18:27:06 +0800
-Subject: [PATCH] add npu patch
-
----
- funasr/auto/auto_model.py | 1 +
- funasr/models/sense_voice/model.py | 56 +++++++++++++++---------------
- 2 files changed, 29 insertions(+), 28 deletions(-)
-
 diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py
 index f5cbe01f..7d8fff13 100644
 --- a/funasr/auto/auto_model.py
 +++ b/funasr/auto/auto_model.py
@@ -21,7 +11,7 @@ index f5cbe01f..7d8fff13 100644
  import string
  import logging
 diff --git a/funasr/models/sense_voice/model.py b/funasr/models/sense_voice/model.py
-index 70cd02e3..8db7a46a 100644
+index 70cd02e3..7adca0de 100644
 --- a/funasr/models/sense_voice/model.py
 +++ b/funasr/models/sense_voice/model.py
 @@ -4,6 +4,7 @@ import time
@@ -34,7 +24,7 @@ index 70cd02e3..8db7a46a 100644
 
  from torch.cuda.amp import autocast
 @@ -158,20 +159,10 @@ class MultiHeadedAttentionSANM(nn.Module):
             torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
 
         """
-        b, t, d = x.size()
         q_k_v = self.linear_q_k_v(x)
@@ -52,12 +42,12 @@ index 70cd02e3..8db7a46a 100644
-        return q_h, k_h, v_h, v
+
+        return q, k, v
 
     def forward_attention(self, value, scores, mask, mask_att_chunk_encoder=None):
         """Compute attention context vector.
@@ -67,7 +57,7 @@ index 70cd02e3..8db7a46a 100644
 @@ -225,13 +216,22 @@ class MultiHeadedAttentionSANM(nn.Module):
             torch.Tensor: Output tensor (#batch, time1, d_model).
 
         """
-        q_h, k_h, v_h, v = self.forward_qkv(x)
+        q, k, v = self.forward_qkv(x)
@@ -81,9 +71,9 @@ index 70cd02e3..8db7a46a 100644
-        att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
+        att_outs = self.npu_flash_attention(q, k, v)  # replace the attention block with the NPU PFA operator
         return att_outs + fsmn_memory
 
+    def npu_flash_attention(self, query, key, value):
+        x = torch_npu.npu_prompt_flash_attention(
+            query,
+            key,
+            value,
+            num_heads=self.h,
+            scale_value=self.d_k ** -0.5,
+            input_layout="BSH",
+        )
+        return x
 
     def forward_chunk(self, x, cache=None, chunk_size=None, look_back=0):
         """Compute scaled dot product attention.
- 
+
 @@ -278,10 +278,10 @@ class LayerNorm(nn.LayerNorm):
- 
+
     def forward(self, input):
         output = F.layer_norm(
-            input.float(),
+            input,
 
+        x = residual + self.dropout(self.feed_forward(x))
         if not self.normalize_before:
             x = self.norm2(x)
 
 @@ -557,7 +557,7 @@ class SenseVoiceEncoderSmall(nn.Module):
     ):
         """Embed positions in tensor."""
         maxlen = xs_pad.shape[1]
-        masks = sequence_mask(ilens, maxlen=maxlen, device=ilens.device)[:, None, :]
+        masks = None
 
         xs_pad *= self.output_size() ** 0.5
 
 @@ -575,7 +575,7 @@ class SenseVoiceEncoderSmall(nn.Module):
         xs_pad = self.after_norm(xs_pad)
 
         # forward encoder2
-        olens = masks.squeeze(1).sum(1).int()
+        olens = ilens.int()
 
         for layer_idx, encoder_layer in enumerate(self.tp_encoders):
             encoder_outs = encoder_layer(xs_pad, masks)
 @@ -769,7 +769,7 @@ class SenseVoiceSmall(nn.Module):
         speech = torch.cat((input_query, speech), dim=1)
         speech_lengths += 3
 
-        encoder_out, encoder_out_lens = self.encoder(speech, speech_lengths)
+        encoder_out, encoder_out_lens = self.encoder(speech.to(torch.float16), speech_lengths.to(torch.float16))
 
         return encoder_out, encoder_out_lens
- 
--- 
-2.21.0
+
+@@ -876,7 +876,7 @@ class SenseVoiceSmall(nn.Module):
+         speech_lengths += 3
+ 
+         # Encoder
+-        encoder_out, encoder_out_lens = self.encoder(speech, speech_lengths)
++        encoder_out, encoder_out_lens = self.encoder(speech.to(torch.float16), speech_lengths.to(torch.float16))
+         if isinstance(encoder_out, tuple):
+             encoder_out = encoder_out[0]
-- 
Gitee
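For readers applying the patch above, the standalone sketch below shows the intent of the attention and precision changes outside of diff syntax: q, k and v stay in flat (batch, time, heads*d_k) "BSH" layout and are handed to torch_npu.npu_prompt_flash_attention instead of a hand-written scale/softmax/matmul pipeline, and encoder inputs are cast to float16. This is a minimal illustration, not part of diff.patch: the SANMSelfAttention class, its sizes, and the exact keyword arguments (num_heads, scale_value, input_layout) are assumptions based on the torch_npu PromptFlashAttention interface, and it only runs on an Ascend NPU with torch_npu installed.

# Illustrative sketch only -- not part of diff.patch. Assumes an Ascend NPU
# environment where torch_npu provides npu_prompt_flash_attention.
import math

import torch
import torch.nn as nn
import torch_npu  # Ascend PyTorch adapter; assumed to be installed with an NPU visible


class SANMSelfAttention(nn.Module):
    """Simplified stand-in for the attention path of MultiHeadedAttentionSANM."""

    def __init__(self, n_head: int, n_feat: int):
        super().__init__()
        self.h = n_head
        self.d_k = n_feat // n_head
        # One projection producing q, k and v side by side, as in FunASR.
        self.linear_q_k_v = nn.Linear(n_feat, 3 * n_feat)
        self.linear_out = nn.Linear(n_feat, n_feat)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, time, n_feat); q/k/v stay in flat "BSH" layout,
        # so no reshape/transpose into (batch, head, time, d_k) is needed.
        q, k, v = torch.split(self.linear_q_k_v(x), self.h * self.d_k, dim=-1)
        # PromptFlashAttention fuses scaling, softmax and both matmuls on the NPU.
        # Keyword arguments are assumptions based on the torch_npu interface.
        att = torch_npu.npu_prompt_flash_attention(
            q, k, v,
            num_heads=self.h,
            scale_value=1.0 / math.sqrt(self.d_k),
            input_layout="BSH",
        )
        return self.linear_out(att)


if __name__ == "__main__":
    attn = SANMSelfAttention(n_head=4, n_feat=512).npu().half()
    speech = torch.randn(2, 100, 512).npu().half()  # the patch casts encoder inputs to float16
    print(attn(speech).shape)  # expected: torch.Size([2, 100, 512])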