diff --git a/MindIE/MindIE-Torch/built-in/audio/Whisper/README.md b/MindIE/MindIE-Torch/built-in/audio/Whisper/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..8c8d30a3a40812a8e1333b1e4fcfd66803e4caa1
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/audio/Whisper/README.md
@@ -0,0 +1,168 @@
+# Whisper Inference Guide
+
+
+- [Overview](#overview)
+
+- [Inference Environment Setup](#inference-environment-setup-all-versions)
+
+- [Quick Start](#quick-start)
+
+- [Inference Performance and Accuracy](#inference-performance-and-accuracy)
+
+
+# Overview
+
+([From the open-source repository](https://github.com/openai/whisper)) Whisper is a general-purpose speech recognition model. It is trained on a large dataset of diverse audio and is also a multitask model that can perform multilingual speech recognition, speech translation, and language identification.
+
+
+# Inference Environment Setup (All Versions)
+
+- This model requires the following plugins and drivers.
+
+  **Table 1** Version compatibility
+  | Component | Version |
+  |---------| ------- |
+  | Firmware & Driver | 24.1.rc1 |
+ | CANN | 8.0.rc1 |
+ | Python | 3.10.13 |
+ | PyTorch | 2.1.0 |
+ | MindIE | 1.0.RC2.B071 |
+
+# 快速上手
+
+1. Download the source code
+ ```
+ git clone https://github.com/openai/whisper.git
+ cd whisper
+ git reset --hard ba3f3cd54b0e5b8ce1ab3de13e32122d0d5f98ab
+ ```
+2. Export the models
+ ```
+ patch -p1 < ../trace_model.patch
+ pip3 install .
+ cd ..
+ wget https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav
+ mkdir /tmp/models
+ whisper zh.wav --model tiny
+ ```
+   The steps above depend on `ffmpeg`; on Ubuntu it can be installed with `apt-get install ffmpeg`. After they complete, six files are generated under `/tmp/models`: `encoder.ts/onnx`, `decoder_prefill.ts/onnx`, and `decoder_decode.ts/onnx`.
+
+   To change the model save path, manually edit `whisper/decoding.py` and `whisper/model.py` after applying the patch; the inference step later also needs the load paths of the corresponding models updated.
+
+   *If an `[SSL: CERTIFICATE_VERIFY_FAILED]` error occurs during the download, add the following code to the top of the Python file that raises the error and run it again:*
+ ```python
+ import ssl
+ ssl._create_default_https_context = ssl._create_unverified_context
+ ```
+
+3. Compile the models
+ ```
+ python3 compile.py
+ ```
+   After it finishes, four files are generated under `/tmp/models`: `encoder_compiled.ts`, `language_detection_compiled.ts`, `decoder_prefill_compiled.ts`, and `decoder_decode_compiled.ts`.
+
+   Parameter description (an example command follows the list):
+   - --model_path: Path to the exported TorchScript models; the compiled models are saved to the same path. Defaults to `/tmp/models`.
+   - --beam_size: Beam search width, default 5. Keep it consistent with the inference setting; if it was specified at model export time, the same value must be used at compile time.
+   - --nblocks: Number of decoder blocks, determined by the model size: tiny 4, base 6, small 12, medium 24, large 32.
+   - --hidden: Feature (hidden) dimension, determined by the model size: tiny 384, base 512, small 768, medium 1024, large 1280.
+   - --soc_version: Chip type; currently only verified on Ascend310P3.
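+
+   For example, assuming the models were exported from the `base` checkpoint (so, per the list above, `nblocks` is 6 and `hidden` is 512), the compile command might look like:
+   ```
+   python3 compile.py --model_path /tmp/models --beam_size 5 --nblocks 6 --hidden 512 --soc_version Ascend310P3
+   ```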
+
+4. Run inference
+ ```
+ cd whisper
+ git reset --hard ba3f3cd54b0e5b8ce1ab3de13e32122d0d5f98ab
+ patch -p1 < ../mindietorch_infer.patch
+ pip3 install .
+ cd ..
+ whisper zh.wav --model tiny
+ ```
+   After inference finishes, the following output is printed to the command line:
+ ```
+ [00:00.000 --> 00:04.480] 我認為跑步最重要的就是給我帶來了身體健康
+ ```
+   For Simplified Chinese output, use the following command:
+ ```
+ whisper zh.wav --model tiny --initial_prompt "简体翻译:"
+ ```
+
+   Note: the default chip ID is `0` and the default model path is `/tmp/models`. To change them, manually edit `whisper/decoding.py` and `whisper/model.py` after applying the patch, e.g. by globally replacing `npu:0`, `/tmp/models` and `mindietorch.set_device(0)` in those files.
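+
+   As a minimal sketch, assuming the default model path is kept and only the chip ID is changed to 1, the global replacement could be done after applying the patch and before running `pip3 install .`:
+   ```
+   sed -i 's/npu:0/npu:1/g; s/set_device(0)/set_device(1)/g' whisper/decoding.py whisper/model.py
+   ```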
+
+
+# Inference Performance and Accuracy
+
+1. Accuracy validation
+ ```
+ python3 precision_test.py
+ ```
+
+   Parameter description:
+   - --sim_threshold: Cosine similarity threshold, default 0.99.
+   - --ntokens: Number of input tokens in the prefill stage and number of cached tokens in the decode stage, default 100.
+
+   After it finishes, the expected output is as follows:
+ ```
+ === Compare the outputs of ONNX and AIE ===
+ Start comparing encoder...
+ Number of outputs to compare: 1
+ Number of outputs with cosine similarity > 0.99: 1
+ Number of outputs to compare: 3
+ Number of outputs with cosine similarity > 0.99: 3
+ Number of outputs to compare: 3
+ Number of outputs with cosine similarity > 0.99: 3
+ ```
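+
+   `precision_test.py` additionally accepts `--beam_size`, `--nblocks`, `--hidden`, `--device_id` and the ONNX/AIE model paths (see its argument parser). As a sketch, a check against a `base`-size export might look like:
+   ```
+   python3 precision_test.py --sim_threshold 0.99 --ntokens 100 --nblocks 6 --hidden 512
+   ```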
+
+2. Performance validation
+
+   a) AIE model performance test
+ ```
+ python perf_test_aie.py
+ ```
+
+   After it finishes, the expected output is as follows:
+ ```
+ Encoder latency: 7.75 ms
+ Encoder throughput: 128.97 fps
+ Decoder prefill latency: 10.14 ms
+ Decoder prefill throughput: 98.63 fps
+ Decoder decode latency: 2.92 ms
+ Decoder decode throughput: 342.55 fps
+ ```
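+
+   `perf_test_aie.py` also exposes `--device_id`, `--beam_size`, `--ntokens`, `--nblocks`, `--hidden` and the compiled-model paths (see its argument parser); for a `base`-size export, a run might look like:
+   ```
+   python perf_test_aie.py --nblocks 6 --hidden 512
+   ```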
+
+   b) ONNX model performance test
+   (Optional) To run on a GPU, make sure CUDA and the GPU build of PyTorch are installed, and additionally install onnxruntime-gpu as shown below:
+ ```shell
+ pip uninstall onnxruntime
+ pip install onnxruntime-gpu
+ ```
+   Verify that onnxruntime-gpu was installed successfully:
+ ```python
+ import onnxruntime
+   print(onnxruntime.get_device())  # if this prints GPU, the installation succeeded
+ ```
+   Run the performance test:
+ ```
+ python perf_test_onnx.py --use_gpu
+ ```
+
+   Parameter description:
+   - --use_gpu: Enable GPU inference; without this option the CPU is used by default.
+
+   After it finishes, the expected output is as follows:
+ ```
+ Encoder latency: 59.49 ms
+ Encoder throughput: 16.81 fps
+ Decoder prefill latency: 141.14 ms
+ Decoder prefill throughput: 7.09 fps
+ Decoder decode latency: 36.05 ms
+ Decoder decode throughput: 27.74 fps
+ ```
+
+
+   | Model | PT plugin - 310P performance (latency / throughput) | T4 performance (latency / throughput) | A10 performance (latency / throughput) |
+ |---------|--------------------------------|---------------------|--------------------|
+ | encoder | 7.75 ms / 128.97 fps | 9.31 ms / 107.47 fps | 4.21 ms / 237.50 fps |
+ | prefill | 10.14 ms / 98.63 fps | 72.08 ms / 13.87 fps | 45.15 ms / 22.15 fps |
+ | decode | 2.92 ms / 342.55 fps | 10.46 ms / 95.62 fps | 4.91 ms / 203.61 fps |
+
+   Note: in an actual inference run the encoder and prefill are each invoked once, while decode is invoked many times (the figures above assume a cached token length of 100); with 100 decode steps that amounts to roughly 7.75 + 10.14 + 100 × 2.92 ≈ 310 ms of pure model time on the 310P. The end-to-end Whisper pipeline also includes post-processing, cache rearrangement, and other steps, so the figures above are for reference only.
\ No newline at end of file
diff --git a/MindIE/MindIE-Torch/built-in/audio/Whisper/compile.py b/MindIE/MindIE-Torch/built-in/audio/Whisper/compile.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9545ced634277d6ed4a3cc97cc6507524fe0284
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/audio/Whisper/compile.py
@@ -0,0 +1,133 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+
+import torch
+import mindietorch
+
+_N_MEL = 80  # mel bins in Whisper's input features
+_FRAMES = 3000  # mel frames in a 30 s audio chunk
+_HALF_FRAMES = 1500  # encoder output length after 2x downsampling
+_MAX_TOKEN = 224  # maximum number of decoded tokens per chunk
+_KV_NUM = 2  # key and value entries in each attention cache
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="mindietorch model compilation")
+ parser.add_argument("--model_path", default="/tmp/models")
+ parser.add_argument("--beam_size", type=int, default=5)
+ parser.add_argument("--nblocks", type=int, default=4)
+ parser.add_argument("--hidden", type=int, default=384)
+ parser.add_argument("--soc_version", default="Ascend310P3")
+ args = parser.parse_args()
+ return args
+
+def compile_and_save(ts_model, input_info, soc_version, save_path):
+ ts_model.eval()
+ mindie_model = mindietorch.compile(
+ ts_model,
+ inputs=input_info,
+ precision_policy=mindietorch._enums.PrecisionPolicy.FP16,
+ truncate_long_and_double=True,
+ allow_tensor_replace_int=True,
+ soc_version=soc_version,
+ optimization_level=0
+ )
+ mindie_model.save(save_path)
+
+def encoder(args):
+ ts_model = torch.jit.load(f"{args.model_path}/encoder.ts")
+ input_mel_info = mindietorch.Input([1, _N_MEL, _FRAMES])
+ input_info = [input_mel_info]
+ save_path = f"{args.model_path}/encoder_compiled.ts"
+ compile_and_save(ts_model, input_info, args.soc_version, save_path)
+
+def language(args):
+ ts_model = torch.jit.load(f"{args.model_path}/decoder_prefill.ts")
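+    # Language detection reuses the traced prefill decoder, compiled here with a
+    # static single-token, single-audio-batch input shape.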
+ input_tokens_info = mindietorch.Input([1, 1])
+ input_audio_features_info = mindietorch.Input([1, _HALF_FRAMES, args.hidden])
+ input_pos_embed_info = mindietorch.Input([1, args.hidden])
+ input_info = [
+ input_tokens_info,
+ input_audio_features_info,
+ input_pos_embed_info,
+ ]
+ save_path = f"{args.model_path}/language_detection_compiled.ts"
+ compile_and_save(ts_model, input_info, args.soc_version, save_path)
+
+def prefill(args):
+ ts_model = torch.jit.load(f"{args.model_path}/decoder_prefill.ts")
+
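+    # The prompt length varies at prefill time, so tokens and pos_embed are compiled
+    # with a dynamic range up to _MAX_TOKEN; the encoder output shape stays fixed.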
+ input_tokens_info = mindietorch.Input(
+ min_shape=[args.beam_size, 1],
+ max_shape=[args.beam_size, _MAX_TOKEN]
+ )
+ input_audio_features_info = mindietorch.Input(
+ min_shape=[1, _HALF_FRAMES, args.hidden],
+ max_shape=[1, _HALF_FRAMES, args.hidden]
+ )
+ input_pos_embed_info = mindietorch.Input(
+ min_shape=[1, args.hidden],
+ max_shape=[_MAX_TOKEN, args.hidden]
+ )
+ input_info = [
+ input_tokens_info,
+ input_audio_features_info,
+ input_pos_embed_info,
+ ]
+ save_path = f"{args.model_path}/decoder_prefill_compiled.ts"
+ compile_and_save(ts_model, input_info, args.soc_version, save_path)
+
+def decode(args):
+ ts_model = torch.jit.load(f"{args.model_path}/decoder_decode.ts")
+
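+    # Decode consumes one new token per step; only the self-attention cache
+    # (cache_dyn) grows with the generated length, while the cross-attention
+    # cache (cache_sta) keeps its fixed 1500-frame shape.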
+ input_tokens_info = mindietorch.Input(
+ min_shape=[args.beam_size, 1],
+ max_shape=[args.beam_size, 1]
+ )
+ input_audio_features_info = mindietorch.Input(
+ min_shape=[1, _HALF_FRAMES, args.hidden],
+ max_shape=[1, _HALF_FRAMES, args.hidden]
+ )
+ input_pos_embed_info = mindietorch.Input(
+ min_shape=[args.hidden],
+ max_shape=[args.hidden]
+ )
+ input_cache_dyn_info = mindietorch.Input(
+ min_shape=(args.nblocks, _KV_NUM, args.beam_size, 1, args.hidden),
+ max_shape=(args.nblocks, _KV_NUM, args.beam_size, _MAX_TOKEN, args.hidden)
+ )
+ input_cache_sta_info = mindietorch.Input(
+ min_shape=[args.nblocks, _KV_NUM, 1, _HALF_FRAMES, args.hidden],
+ max_shape=[args.nblocks, _KV_NUM, 1, _HALF_FRAMES, args.hidden]
+ )
+
+ input_info = [
+ input_tokens_info,
+ input_audio_features_info,
+ input_pos_embed_info,
+ input_cache_dyn_info,
+ input_cache_sta_info
+ ]
+
+ save_path = f"{args.model_path}/decoder_decode_compiled.ts"
+ compile_and_save(ts_model, input_info, args.soc_version, save_path)
+
+def main():
+ args = parse_args()
+ for func in encoder, language, prefill, decode:
+ func(args)
+
+if __name__ == '__main__':
+ main()
\ No newline at end of file
diff --git a/MindIE/MindIE-Torch/built-in/audio/Whisper/mindietorch_infer.patch b/MindIE/MindIE-Torch/built-in/audio/Whisper/mindietorch_infer.patch
new file mode 100644
index 0000000000000000000000000000000000000000..fc7f771847486b8542d41a2a54876304c481399e
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/audio/Whisper/mindietorch_infer.patch
@@ -0,0 +1,226 @@
+diff --git a/whisper/decoding.py b/whisper/decoding.py
+index 49485d0..4dccc86 100644
+--- a/whisper/decoding.py
++++ b/whisper/decoding.py
+@@ -6,6 +6,7 @@ import torch
+ import torch.nn.functional as F
+ from torch import Tensor
+ from torch.distributions import Categorical
++import mindietorch
+
+ from .audio import CHUNK_LENGTH
+ from .tokenizer import Tokenizer, get_tokenizer
+@@ -14,6 +15,7 @@ from .utils import compression_ratio
+ if TYPE_CHECKING:
+ from .model import Whisper
+
++mindietorch.set_device(0)
+
+ @torch.no_grad()
+ def detect_language(
+@@ -54,7 +56,7 @@ def detect_language(
+ # forward pass using a single token, startoftranscript
+ n_audio = mel.shape[0]
+ x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device) # [n_audio, 1]
+- logits = model.logits(x, mel)[:, 0]
++ logits = model.logits(x, mel)[0][:, 0]
+
+ # collect detected languages; suppress all non-language tokens
+ mask = torch.ones(logits.shape[-1], dtype=torch.bool)
+@@ -145,36 +147,35 @@ class PyTorchInference(Inference):
+ def __init__(self, model: "Whisper", initial_token_length: int):
+ self.model: "Whisper" = model
+ self.initial_token_length = initial_token_length
+- self.kv_cache = {}
+- self.hooks = []
+-
+- key_modules = [block.attn.key for block in self.model.decoder.blocks]
+- value_modules = [block.attn.value for block in self.model.decoder.blocks]
+- self.kv_modules = key_modules + value_modules
++ self.cache_dyn = None
++ self.cache_sta = None
+
+ def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
+- if not self.kv_cache:
+- self.kv_cache, self.hooks = self.model.install_kv_cache_hooks()
+-
+ if tokens.shape[-1] > self.initial_token_length:
+ # only need to use the last token except in the first forward pass
+ tokens = tokens[:, -1:]
++ pos_embed = self.model.decoder.positional_embedding[self.cache_dyn.shape[3]]
++ logits, cache_dyn, _ = self.model.decoder(
++ tokens, audio_features, pos_embed, self.cache_dyn, self.cache_sta)
++ self.cache_dyn = cache_dyn
++ else:
++ pos_embed = self.model.decoder.positional_embedding[:tokens.shape[-1]]
++ logits, cache_dyn, cache_sta = self.model.decoder(tokens, audio_features, pos_embed)
++ self.cache_dyn = cache_dyn
++ self.cache_sta = cache_sta
+
+- return self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache)
++ return logits
+
+ def cleanup_caching(self):
+- for hook in self.hooks:
+- hook.remove()
+-
+- self.kv_cache = {}
+- self.hooks = []
++ self.cache_dyn = None
++ self.cache_sta = None
+
+ def rearrange_kv_cache(self, source_indices):
+ if source_indices != list(range(len(source_indices))):
+- for module in self.kv_modules:
+- # update the key/value cache to contain the selected sequences
+- self.kv_cache[module] = self.kv_cache[module][source_indices].detach()
+-
++ blocks = self.cache_dyn.shape[0]
++ for i in range(blocks):
++ for j in range(2): # k and v 2 items
++ self.cache_dyn[i][j] = self.cache_dyn[i][j][source_indices]
+
+ class SequenceRanker:
+ def rank(
+diff --git a/whisper/model.py b/whisper/model.py
+index a678283..c94a024 100644
+--- a/whisper/model.py
++++ b/whisper/model.py
+@@ -1,12 +1,14 @@
+ import base64
+ import gzip
+ from dataclasses import dataclass
++import os
+ from typing import Dict, Iterable, Optional
+
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ from torch import Tensor, nn
++import mindietorch
+
+ from .decoding import decode as decode_function
+ from .decoding import detect_language as detect_language_function
+@@ -153,24 +155,19 @@ class AudioEncoder(nn.Module):
+ [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
+ )
+ self.ln_post = LayerNorm(n_state)
++ self.device = "npu:0"
++ self.mindietorch_encoder_model = torch.jit.load(
++ "/tmp/models/encoder_compiled.ts"
++ ).eval().to(self.device)
+
+ def forward(self, x: Tensor):
+ """
+ x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
+ the mel spectrogram of the audio
+ """
+- x = F.gelu(self.conv1(x))
+- x = F.gelu(self.conv2(x))
+- x = x.permute(0, 2, 1)
+-
+- assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
+- x = (x + self.positional_embedding).to(x.dtype)
+-
+- for block in self.blocks:
+- x = block(x)
+-
+- x = self.ln_post(x)
+- return x
++ x = x.to(self.device)
++ x = self.mindietorch_encoder_model(x)
++ return x.cpu()
+
+
+ class TextDecoder(nn.Module):
+@@ -193,29 +190,58 @@ class TextDecoder(nn.Module):
+ mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
+ self.register_buffer("mask", mask, persistent=False)
+
+- def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
+- """
+- x : torch.LongTensor, shape = (batch_size, <= n_ctx)
+- the text tokens
+- xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
+- the encoded audio features to be attended on
+- """
+- offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
+- x = (
+- self.token_embedding(x)
+- + self.positional_embedding[offset : offset + x.shape[-1]]
+- )
+- x = x.to(xa.dtype)
+-
+- for block in self.blocks:
+- x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
++ self.device = "npu:0"
++ self.mindietorch_language_detection_model = torch.jit.load(
++ "/tmp/models/language_detection_compiled.ts"
++ ).eval().to(self.device)
++ self.mindietorch_prefill_model = torch.jit.load(
++ "/tmp/models/decoder_prefill_compiled.ts"
++ ).eval().to(self.device)
++ self.mindietorch_decode_model = torch.jit.load(
++ "/tmp/models/decoder_decode_compiled.ts"
++ ).eval().to(self.device)
+
+- x = self.ln(x)
+- logits = (
+- x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
+- ).float()
+-
+- return logits
++ def forward(
++ self,
++ x: Tensor,
++ xa: Tensor,
++ pos_embed: Tensor = None,
++ cache_dyn: Tensor = None,
++ cache_sta: Tensor = None,
++ ):
++ if cache_dyn is None:
++ tokens_npu = x.float().to(self.device)
++ audio_features_npu = xa.to(self.device)
++ pos_embed_npu = pos_embed.to(self.device)
++ if x.shape[0] != 1:
++ logits, cache_dyn, cache_sta = self.mindietorch_prefill_model(
++ tokens_npu,
++ audio_features_npu,
++ pos_embed_npu
++ )
++ else:
++ logits, cache_dyn, cache_sta = self.mindietorch_language_detection_model(
++ tokens_npu,
++ audio_features_npu,
++ pos_embed_npu
++ )
++ logits = logits.cpu()
++ cache_dyn = cache_dyn.cpu()
++ else:
++ tokens_npu = x.float().to(self.device)
++ audio_features_npu = xa.to(self.device)
++ pos_embed_npu = pos_embed.to(self.device)
++ cache_dyn_npu = cache_dyn.to(self.device)
++ logits, cache_dyn, _ = self.mindietorch_decode_model(
++ tokens_npu,
++ audio_features_npu,
++ pos_embed_npu,
++ cache_dyn_npu,
++ cache_sta
++ )
++ logits = logits.cpu()
++ cache_dyn = cache_dyn.cpu()
++ return logits, cache_dyn, cache_sta
+
+
+ class Whisper(nn.Module):
+@@ -257,7 +283,8 @@ class Whisper(nn.Module):
+ return self.encoder(mel)
+
+ def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
+- return self.decoder(tokens, audio_features)
++ pos_embed = self.decoder.positional_embedding[:tokens.shape[-1]]
++ return self.decoder(tokens, audio_features, pos_embed)
+
+ def forward(
+ self, mel: torch.Tensor, tokens: torch.Tensor
diff --git a/MindIE/MindIE-Torch/built-in/audio/Whisper/perf_test_aie.py b/MindIE/MindIE-Torch/built-in/audio/Whisper/perf_test_aie.py
new file mode 100644
index 0000000000000000000000000000000000000000..6f7cfdc2192a7f41c015e9259469e0101c726a5b
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/audio/Whisper/perf_test_aie.py
@@ -0,0 +1,128 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import time
+import torch
+import mindietorch
+
+_N_MEL = 80
+_FRAMES = 3000
+_HALF_FRAMES = 1500
+_MAX_TOKEN = 224
+_KV_NUM = 2
+
+
+def test(inputs, model, stream, meta=""):
+ # warmup
+ for _ in range(10):
+ with mindietorch.npu.stream(stream):
+ model(*inputs)
+ stream.synchronize()
+
+ # performance test
+ num_infer = 100
+ start = time.time()
+ for _ in range(num_infer):
+ with mindietorch.npu.stream(stream):
+ model(*inputs)
+ stream.synchronize()
+ end = time.time()
+
+ print(f"{meta} latency: {(end - start) / num_infer * 1000:.2f} ms")
+ print(f"{meta} throughput: {num_infer / (end - start):.2f} fps")
+
+
+def test_encoder(args):
+ device = f'npu:{args.device_id}'
+ stream = mindietorch.npu.Stream(device)
+ model = torch.jit.load(args.encoder_aie_path)
+ model.eval()
+
+ inputs = [
+ torch.ones((1, _N_MEL, _FRAMES), dtype=torch.float32).to(device)
+ ]
+
+ test(inputs, model, stream, "Encoder")
+
+
+def test_decoder_prefill(args):
+ device = f'npu:{args.device_id}'
+ stream = mindietorch.npu.Stream(device)
+ model = torch.jit.load(args.decoder_prefill_aie_path)
+ model.eval()
+
+ assert args.ntokens <= _MAX_TOKEN, f'ntokens can not exceed {_MAX_TOKEN}'
+
+ inputs = [
+ torch.ones((args.beam_size, args.ntokens), dtype=torch.float32).to(device),
+ torch.ones((1, _HALF_FRAMES, args.hidden), dtype=torch.float32).to(device),
+ torch.ones((args.ntokens, args.hidden), dtype=torch.float32).to(device)
+ ]
+
+ test(inputs, model, stream, "Decoder prefill")
+
+
+def test_decoder_decode(args):
+ device = f'npu:{args.device_id}'
+ stream = mindietorch.npu.Stream(device)
+ model = torch.jit.load(args.decoder_decode_aie_path)
+ model.eval()
+
+ inputs = [
+ torch.ones((args.beam_size, 1), dtype=torch.float32).to(device),
+ torch.ones((1, _HALF_FRAMES, args.hidden), dtype=torch.float32).to(device),
+ torch.ones((args.hidden), dtype=torch.float32).to(device),
+ torch.ones((args.nblocks, _KV_NUM, args.beam_size, args.ntokens, args.hidden), dtype=torch.float32).to(device),
+ torch.ones((args.nblocks, _KV_NUM, 1, _HALF_FRAMES, args.hidden), dtype=torch.float32).to(device),
+ ]
+
+ test(inputs, model, stream, "Decoder decode")
+
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--encoder_aie_path",
+ type=str, default="/tmp/models/encoder_compiled.ts"
+ )
+ parser.add_argument(
+ "--decoder_prefill_aie_path",
+ type=str, default="/tmp/models/decoder_prefill_compiled.ts"
+ )
+ parser.add_argument(
+ "--decoder_decode_aie_path",
+ type=str, default="/tmp/models/decoder_decode_compiled.ts"
+ )
+ parser.add_argument("--beam_size", type=int, default=5)
+ parser.add_argument("--ntokens", type=int, default=100)
+ parser.add_argument("--nblocks", type=int, default=4)
+ parser.add_argument("--hidden", type=int, default=384)
+ parser.add_argument("--device_id", type=int, help="NPU device id", default=0)
+
+ args = parser.parse_args()
+ return args
+
+
+def main():
+ args = parse_args()
+
+ mindietorch.set_device(args.device_id)
+
+ for func in test_encoder, test_decoder_prefill, test_decoder_decode:
+ func(args)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/MindIE/MindIE-Torch/built-in/audio/Whisper/perf_test_onnx.py b/MindIE/MindIE-Torch/built-in/audio/Whisper/perf_test_onnx.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd893efd82fc5ac8987c7c790d37ea386cbf4845
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/audio/Whisper/perf_test_onnx.py
@@ -0,0 +1,120 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import time
+import onnxruntime as ort
+import numpy as np
+
+_N_MEL = 80
+_FRAMES = 3000
+_MAX_TOKEN = 224
+_HALF_FRAMES = 1500
+_KV_NUM = 2
+
+
+def test(encoder_path, provider, output_names, onnx_inputs, meta=""):
+ onnx_model = ort.InferenceSession(
+ encoder_path,
+ providers=[provider]
+ )
+
+ # warmup
+ for _ in range(10):
+ onnx_model.run(output_names, onnx_inputs)
+ # performance test
+ num_infer = 100
+ start = time.time()
+ for _ in range(num_infer):
+ onnx_model.run(output_names, onnx_inputs)
+ end = time.time()
+
+ print(f"{meta} latency: {(end - start) / num_infer * 1000:.2f} ms")
+ print(f"{meta} throughput: {num_infer / (end - start):.2f} fps")
+
+
+def test_encoder(args, provider):
+ x = np.ones((1, _N_MEL, _FRAMES), dtype=np.float16 if args.use_gpu else np.float32)
+ onnx_inputs = {'mel': ort.OrtValue.ortvalue_from_numpy(x)}
+ output_names = ['ret']
+
+ test(args.encoder_onnx_path, provider, output_names, onnx_inputs, "Encoder")
+
+
+def test_decoder_prefill(args, provider):
+ assert args.ntokens <= _MAX_TOKEN, f'ntokens can not exceed {_MAX_TOKEN}'
+ tokens = np.ones((args.beam_size, args.ntokens), dtype=np.int64)
+ audio_features = np.ones((1, _HALF_FRAMES, args.hidden), dtype=np.float16 if args.use_gpu else np.float32)
+ pos_embed = np.ones((args.ntokens, args.hidden), dtype=np.float32)
+ onnx_inputs = {
+ 'tokens': ort.OrtValue.ortvalue_from_numpy(tokens),
+ 'audio_features': ort.OrtValue.ortvalue_from_numpy(audio_features),
+ 'pos_embed': ort.OrtValue.ortvalue_from_numpy(pos_embed)
+ }
+ output_names = ["logits", "cache_dyn", "cache_sta"]
+
+ test(args.decoder_prefill_onnx_path, provider, output_names, onnx_inputs, "Decoder prefill")
+
+
+def test_decoder_decode(args, provider):
+ assert args.ntokens <= _MAX_TOKEN, f'ntokens can not exceed {_MAX_TOKEN}'
+ tokens = np.ones((args.beam_size, 1), dtype=np.int64)
+ pos_embed = np.ones((args.hidden), dtype=np.float32)
+ cache_dyn = np.ones(
+ (args.nblocks, _KV_NUM, args.beam_size, args.ntokens, args.hidden),
+ dtype=np.float16 if args.use_gpu else np.float32
+ )
+ cache_sta = np.ones(
+ (args.nblocks, _KV_NUM, 1, _HALF_FRAMES, args.hidden),
+ dtype=np.float16 if args.use_gpu else np.float32
+ )
+ onnx_inputs = {
+        'tokens': ort.OrtValue.ortvalue_from_numpy(tokens),  # audio_features was folded away during ONNX export
+ 'pos_embed': ort.OrtValue.ortvalue_from_numpy(pos_embed),
+ 'cache_dyn': ort.OrtValue.ortvalue_from_numpy(cache_dyn),
+ 'cache_sta': ort.OrtValue.ortvalue_from_numpy(cache_sta)
+ }
+ output_names = ["logits", "new_cache_dyn", "new_cache_sta"]
+
+ test(args.decoder_decode_onnx_path, provider, output_names, onnx_inputs, "Decoder decode")
+
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+    parser.add_argument('--encoder_onnx_path', type=str, default='/tmp/models/encoder.onnx')
+    parser.add_argument('--decoder_prefill_onnx_path', type=str, default='/tmp/models/decoder_prefill.onnx')
+    parser.add_argument('--decoder_decode_onnx_path', type=str, default='/tmp/models/decoder_decode.onnx')
+ parser.add_argument("--use_gpu", action="store_true")
+ parser.add_argument("--beam_size", type=int, default=5)
+ parser.add_argument("--ntokens", type=int, default=100)
+ parser.add_argument("--nblocks", type=int, default=4)
+ parser.add_argument("--hidden", type=int, default=384)
+
+ args = parser.parse_args()
+ return args
+
+
+def main():
+ args = parse_args()
+ if args.use_gpu:
+ provider = "CUDAExecutionProvider"
+ else:
+ provider = "CPUExecutionProvider"
+
+ for func in test_encoder, test_decoder_prefill, test_decoder_decode:
+ func(args, provider)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/MindIE/MindIE-Torch/built-in/audio/Whisper/precision_test.py b/MindIE/MindIE-Torch/built-in/audio/Whisper/precision_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..de9f03ae899e21317e8e05d310bfc071dd23a1bc
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/audio/Whisper/precision_test.py
@@ -0,0 +1,191 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import torch
+import torch.nn.functional as F
+import onnxruntime as ort
+import numpy as np
+import mindietorch
+
+_N_MEL = 80
+_FRAMES = 3000
+_MAX_TOKEN = 224
+_HALF_FRAMES = 1500
+_KV_NUM = 2
+
+def compare_onnx_aie_output(onnx_out, aie_out, sim_threshold=0.99):
+ num_sim = 0
+ for i, (a, b) in enumerate(zip(onnx_out, aie_out)):
+ a = a.reshape(1, -1).astype(np.float32)
+ b = b.reshape(1, -1)
+ sim = F.cosine_similarity(torch.from_numpy(a), b, dim=1)
+ if sim > sim_threshold:
+ num_sim += 1
+ else:
+ print(f'Output {i} similarity: {sim}')
+
+ print(f'Number of outputs to compare: {len(onnx_out)}')
+ print(f'Number of outputs with cosine similarity > {sim_threshold}: {num_sim}')
+
+
+def compare_encoder(args):
+ device = f'npu:{args.device_id}'
+
+ onnx_model = ort.InferenceSession(
+ args.encoder_onnx_path,
+ providers=["CPUExecutionProvider"]
+ )
+
+ x = np.ones((1, _N_MEL, _FRAMES), dtype=np.float32)
+ onnx_inputs = {'mel': ort.OrtValue.ortvalue_from_numpy(x)}
+ output_names = ['ret']
+ onnx_out = onnx_model.run(output_names, onnx_inputs)
+
+ aie_inputs = [x]
+ for i in range(len(aie_inputs)):
+ aie_inputs[i] = torch.from_numpy(aie_inputs[i]).to(device)
+
+ mindietorch.set_device(args.device_id)
+ stream = mindietorch.npu.Stream(device)
+ model = torch.jit.load(args.encoder_aie_path)
+ model.eval().to(device)
+
+ with mindietorch.npu.stream(stream):
+ aie_out = model(*aie_inputs)
+ stream.synchronize()
+
+ if isinstance(aie_out, tuple):
+ aie_out = (x.cpu() for x in aie_out)
+ else:
+ aie_out = aie_out.cpu()
+ compare_onnx_aie_output(onnx_out, aie_out, args.sim_threshold)
+
+
+def compare_decoder_prefill(args):
+ device = f'npu:{args.device_id}'
+
+ onnx_model = ort.InferenceSession(
+ args.decoder_prefill_onnx_path,
+ providers=["CPUExecutionProvider"]
+ )
+
+ assert args.ntokens <= _MAX_TOKEN, f'ntokens can not exceed {_MAX_TOKEN}'
+ tokens = np.ones((args.beam_size, args.ntokens), dtype=np.int64)
+ audio_features = np.ones((1, _HALF_FRAMES, args.hidden), dtype=np.float32)
+ pos_embed = np.ones((args.ntokens, args.hidden), dtype=np.float32)
+ onnx_inputs = {
+ 'tokens': ort.OrtValue.ortvalue_from_numpy(tokens),
+ 'audio_features': ort.OrtValue.ortvalue_from_numpy(audio_features),
+ 'pos_embed': ort.OrtValue.ortvalue_from_numpy(pos_embed)
+ }
+ output_names = ["logits", "cache_dyn", "cache_sta"]
+ onnx_out = onnx_model.run(output_names, onnx_inputs)
+
+ aie_inputs = [tokens.astype(np.float32), audio_features, pos_embed]
+ for i in range(len(aie_inputs)):
+ aie_inputs[i] = torch.from_numpy(aie_inputs[i]).to(device)
+
+ mindietorch.set_device(args.device_id)
+ stream = mindietorch.npu.Stream(device)
+ model = torch.jit.load(args.decoder_prefill_aie_path)
+ model.eval().to(device)
+
+ with mindietorch.npu.stream(stream):
+ aie_out = model(*aie_inputs)
+ stream.synchronize()
+ if isinstance(aie_out, tuple):
+ aie_out = (x.cpu() for x in aie_out)
+ else:
+ aie_out = aie_out.cpu()
+ compare_onnx_aie_output(onnx_out, aie_out, args.sim_threshold)
+
+
+def compare_decoder_decode(args):
+ device = f'npu:{args.device_id}'
+
+ onnx_model = ort.InferenceSession(
+ args.decoder_decode_onnx_path,
+ providers=["CPUExecutionProvider"]
+ )
+
+ assert args.ntokens <= _MAX_TOKEN, f'ntokens can not exceed {_MAX_TOKEN}'
+ tokens = np.ones((args.beam_size, 1), dtype=np.int64)
+ audio_features = np.ones((1, _HALF_FRAMES, args.hidden), dtype=np.float32)
+ pos_embed = np.ones((args.hidden), dtype=np.float32)
+ cache_dyn = np.ones((args.nblocks, _KV_NUM, args.beam_size, args.ntokens, args.hidden), dtype=np.float32)
+ cache_sta = np.ones((args.nblocks, _KV_NUM, 1, _HALF_FRAMES, args.hidden), dtype=np.float32)
+ onnx_inputs = {
+        'tokens': ort.OrtValue.ortvalue_from_numpy(tokens),  # audio_features was folded away during ONNX export
+ 'pos_embed': ort.OrtValue.ortvalue_from_numpy(pos_embed),
+ 'cache_dyn': ort.OrtValue.ortvalue_from_numpy(cache_dyn),
+ 'cache_sta': ort.OrtValue.ortvalue_from_numpy(cache_sta)
+ }
+
+ output_names = ["logits", "new_cache_dyn", "new_cache_sta"]
+ onnx_out = onnx_model.run(output_names, onnx_inputs)
+
+ aie_inputs = [tokens.astype(np.float32), audio_features, pos_embed, cache_dyn, cache_sta]
+ for i in range(len(aie_inputs)):
+ aie_inputs[i] = torch.from_numpy(aie_inputs[i]).to(device)
+
+ mindietorch.set_device(args.device_id)
+ stream = mindietorch.npu.Stream(device)
+ model = torch.jit.load(args.decoder_decode_aie_path)
+ model.eval().to(device)
+
+ with mindietorch.npu.stream(stream):
+ aie_out = model(*aie_inputs)
+ stream.synchronize()
+ if isinstance(aie_out, tuple):
+ aie_out = (x.cpu() for x in aie_out)
+ else:
+ aie_out = aie_out.cpu()
+ compare_onnx_aie_output(onnx_out, aie_out, args.sim_threshold)
+
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+ # encoder
+    parser.add_argument('--encoder_onnx_path', type=str, default='/tmp/models/encoder.onnx')
+    parser.add_argument('--encoder_aie_path', type=str, default='/tmp/models/encoder_compiled.ts')
+    # decoder_prefill
+    parser.add_argument('--decoder_prefill_onnx_path', type=str, default='/tmp/models/decoder_prefill.onnx')
+    parser.add_argument('--decoder_prefill_aie_path', type=str, default='/tmp/models/decoder_prefill_compiled.ts')
+    # decoder_decode
+    parser.add_argument('--decoder_decode_onnx_path', type=str, default='/tmp/models/decoder_decode.onnx')
+ parser.add_argument('--decoder_decode_aie_path', type=str, default='/tmp/models/decoder_decode_compiled.ts')
+ parser.add_argument('--sim_threshold', type=float, default=0.99)
+ parser.add_argument('--device_id', type=int, default=0)
+ parser.add_argument("--beam_size", type=int, default=5)
+ parser.add_argument("--ntokens", type=int, default=100)
+ parser.add_argument("--nblocks", type=int, default=4)
+ parser.add_argument("--hidden", type=int, default=384)
+ args = parser.parse_args()
+ return args
+
+
+def main():
+ args = parse_args()
+
+ print('=== Compare the outputs of ONNX and AIE ===')
+
+ print('Start comparing encoder...')
+ funcs = [compare_encoder, compare_decoder_prefill, compare_decoder_decode]
+ for func in funcs:
+ func(args)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/MindIE/MindIE-Torch/built-in/audio/Whisper/trace_model.patch b/MindIE/MindIE-Torch/built-in/audio/Whisper/trace_model.patch
new file mode 100644
index 0000000000000000000000000000000000000000..a35756ff38f412308096baa0caf2da2880b71613
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/audio/Whisper/trace_model.patch
@@ -0,0 +1,343 @@
+diff --git a/whisper/decoding.py b/whisper/decoding.py
+index 49485d0..495fe45 100644
+--- a/whisper/decoding.py
++++ b/whisper/decoding.py
+@@ -2,6 +2,7 @@ from dataclasses import dataclass, field, replace
+ from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union
+
+ import numpy as np
++import os
+ import torch
+ import torch.nn.functional as F
+ from torch import Tensor
+@@ -49,12 +50,24 @@ def detect_language(
+
+ # skip encoder forward pass if already-encoded audio features were given
+ if mel.shape[-2:] != (model.dims.n_audio_ctx, model.dims.n_audio_state):
++ encoder_ts_model = torch.jit.trace(model.encoder, mel)
++ encoder_ts_model.save(
++ "/tmp/models/encoder.ts")
++ torch.onnx.export(
++ model.encoder,
++ (mel),
++ "/tmp/models/encoder.onnx",
++ opset_version=11,
++ input_names=["mel"],
++ output_names=["ret"]
++ )
++
+ mel = model.encoder(mel)
+
+ # forward pass using a single token, startoftranscript
+ n_audio = mel.shape[0]
+ x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device) # [n_audio, 1]
+- logits = model.logits(x, mel)[:, 0]
++ logits = model.logits(x, mel)[0][:, 0]
+
+ # collect detected languages; suppress all non-language tokens
+ mask = torch.ones(logits.shape[-1], dtype=torch.bool)
+@@ -145,36 +158,74 @@ class PyTorchInference(Inference):
+ def __init__(self, model: "Whisper", initial_token_length: int):
+ self.model: "Whisper" = model
+ self.initial_token_length = initial_token_length
+- self.kv_cache = {}
+- self.hooks = []
+-
+- key_modules = [block.attn.key for block in self.model.decoder.blocks]
+- value_modules = [block.attn.value for block in self.model.decoder.blocks]
+- self.kv_modules = key_modules + value_modules
++ self.cache_dyn = None
++ self.cache_sta = None
+
+ def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
+- if not self.kv_cache:
+- self.kv_cache, self.hooks = self.model.install_kv_cache_hooks()
+-
+ if tokens.shape[-1] > self.initial_token_length:
+ # only need to use the last token except in the first forward pass
+ tokens = tokens[:, -1:]
++ pos_embed = self.model.decoder.positional_embedding[self.cache_dyn.shape[3]]
++ torch.onnx.export(
++ self.model.decoder,
++ (tokens, audio_features, pos_embed, self.cache_dyn, self.cache_sta),
++ "/tmp/models/decoder_decode.onnx",
++ opset_version=11,
++ input_names=["tokens", "audio_features", "pos_embed", "cache_dyn", "cache_sta"],
++ output_names=["logits", "new_cache_dyn", "new_cache_sta"],
++ dynamic_axes={
++ "cache_dyn": {3: "ntokens"},
++ "new_cache_dyn": {3: "ntokens"}
++ }
++ )
++ decoder_decode_ts_model = torch.jit.trace(
++ self.model.decoder,
++ (tokens, audio_features, pos_embed, self.cache_dyn, self.cache_sta)
++ )
++ decoder_decode_ts_model.save(
++ "/tmp/models/decoder_decode.ts")
++ logits, cache_dyn, _ = self.model.decoder(
++ tokens, audio_features, pos_embed, self.cache_dyn, self.cache_sta)
++ os.sys.exit(0)
++ self.cache_dyn = cache_dyn
++ else:
++ pos_embed = self.model.decoder.positional_embedding[:tokens.shape[-1]]
++ torch.onnx.export(
++ self.model.decoder,
++ (tokens, audio_features, pos_embed),
++ "/tmp/models/decoder_prefill.onnx",
++ opset_version=11,
++ input_names=["tokens", "audio_features", "pos_embed"],
++ output_names=["logits", "cache_dyn", "cache_sta"],
++ dynamic_axes={
++ "tokens": {1: "ntokens"},
++ "pos_embed": {0: "ntokens"},
++ "logits": {1: "ntokens"},
++ "cache_dyn": {3: "ntokens"}
++ }
++ )
++ decoder_prefill_ts_model = torch.jit.trace(
++ self.model.decoder,
++ (tokens, audio_features, pos_embed)
++ )
++ decoder_prefill_ts_model.save(
++ "/tmp/models/decoder_prefill.ts")
++ logits, cache_dyn, cache_sta = self.model.decoder(tokens, audio_features, pos_embed)
++ self.cache_dyn = cache_dyn
++ self.cache_sta = cache_sta
+
+- return self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache)
++ return logits
+
+ def cleanup_caching(self):
+- for hook in self.hooks:
+- hook.remove()
+-
+- self.kv_cache = {}
+- self.hooks = []
++ self.cache_dyn = None
++ self.cache_sta = None
+
+ def rearrange_kv_cache(self, source_indices):
+ if source_indices != list(range(len(source_indices))):
+- for module in self.kv_modules:
+- # update the key/value cache to contain the selected sequences
+- self.kv_cache[module] = self.kv_cache[module][source_indices].detach()
+-
++ blocks = self.cache_dyn.shape[0]
++ for i in range(blocks):
++ for j in range(2): # k and v 2 items
++ self.cache_dyn[i][j] = self.cache_dyn[i][j][source_indices]
+
+ class SequenceRanker:
+ def rank(
+diff --git a/whisper/model.py b/whisper/model.py
+index a678283..2a95e28 100644
+--- a/whisper/model.py
++++ b/whisper/model.py
+@@ -1,6 +1,7 @@
+ import base64
+ import gzip
+ from dataclasses import dataclass
++import os
+ from typing import Dict, Iterable, Optional
+
+ import numpy as np
+@@ -68,6 +69,63 @@ class MultiHeadAttention(nn.Module):
+ self.value = Linear(n_state, n_state)
+ self.out = Linear(n_state, n_state)
+
++ def encoder_forward(self, x: Tensor):
++ q = self.query(x)
++ k = self.key(x)
++ v = self.value(x)
++ wv, qk = self.qkv_attention(q, k, v)
++ return self.out(wv)
++
++ def prefill_self_attn_forward(
++ self,
++ x: Tensor,
++ mask: Tensor,
++ ):
++ q = self.query(x)
++ k = self.key(x)
++ v = self.value(x)
++ cache_dyn = torch.stack([k, v])
++ wv, _ = self.qkv_attention(q, k, v, mask)
++ return self.out(wv), cache_dyn
++
++ def prefill_cross_attn_forward(
++ self,
++ x: Tensor,
++ xa: Tensor,
++ ):
++ q = self.query(x)
++ k = self.key(xa)
++ v = self.value(xa)
++ cache_sta = torch.stack([k, v])
++ wv, _ = self.qkv_attention(q, k, v)
++ return self.out(wv), cache_sta
++
++ def decode_self_attn_forward(
++ self,
++ x: Tensor,
++ mask: Tensor,
++ cache_dyn: Tensor
++ ):
++ q = self.query(x)
++ token_k = self.key(x)
++ k = torch.cat([cache_dyn[0], token_k], dim=1).detach()
++ token_v = self.value(x)
++ v = torch.cat([cache_dyn[1], token_v], dim=1).detach()
++ new_cache_dyn = torch.stack([k, v])
++ wv, _ = self.qkv_attention(q, k, v, mask)
++ return self.out(wv), new_cache_dyn
++
++ def decode_cross_attn_forward(
++ self,
++ x: Tensor,
++ cache_sta: Tensor
++ ):
++ q = self.query(x)
++ k = cache_sta[0]
++ v = cache_sta[1]
++ wv, _ = self.qkv_attention(q, k, v)
++ return self.out(wv)
++
+ def forward(
+ self,
+ x: Tensor,
+@@ -126,6 +184,39 @@ class ResidualAttentionBlock(nn.Module):
+ )
+ self.mlp_ln = LayerNorm(n_state)
+
++ def encoder_forward(self, x: Tensor):
++ x = x + self.attn.encoder_forward(self.attn_ln(x))
++ x = x + self.mlp(self.mlp_ln(x))
++ return x
++
++ def prefill_forward(
++ self,
++ x: Tensor,
++ xa: Tensor,
++ mask: Tensor,
++ ):
++ self_attn_out, new_cache_dyn = self.attn.prefill_self_attn_forward(self.attn_ln(x), mask)
++ x = x + self_attn_out
++ cross_attn_out, new_cache_sta = self.cross_attn.prefill_cross_attn_forward(self.cross_attn_ln(x), xa)
++ x = x + cross_attn_out
++ x = x + self.mlp(self.mlp_ln(x))
++ return x, new_cache_dyn, new_cache_sta
++
++ def decode_forward(
++ self,
++ x: Tensor,
++ xa: Tensor,
++ mask: Tensor,
++ cache_dyn: Tensor,
++ cache_sta: Tensor
++ ):
++ self_attn_out, new_cache_dyn = self.attn.decode_self_attn_forward(self.attn_ln(x), mask, cache_dyn)
++ x = x + self_attn_out
++ cross_attn_out = self.cross_attn.decode_cross_attn_forward(self.cross_attn_ln(x), cache_sta)
++ x = x + cross_attn_out
++ x = x + self.mlp(self.mlp_ln(x))
++ return x, new_cache_dyn
++
+ def forward(
+ self,
+ x: Tensor,
+@@ -163,11 +254,10 @@ class AudioEncoder(nn.Module):
+ x = F.gelu(self.conv2(x))
+ x = x.permute(0, 2, 1)
+
+- assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
+ x = (x + self.positional_embedding).to(x.dtype)
+
+ for block in self.blocks:
+- x = block(x)
++ x = block.encoder_forward(x)
+
+ x = self.ln_post(x)
+ return x
+@@ -193,29 +283,56 @@ class TextDecoder(nn.Module):
+ mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
+ self.register_buffer("mask", mask, persistent=False)
+
+- def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
+- """
+- x : torch.LongTensor, shape = (batch_size, <= n_ctx)
+- the text tokens
+- xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
+- the encoded audio features to be attended on
+- """
+- offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
+- x = (
+- self.token_embedding(x)
+- + self.positional_embedding[offset : offset + x.shape[-1]]
+- )
+- x = x.to(xa.dtype)
++ def prefill(self, x: Tensor, xa: Tensor, pos_embed: Tensor):
++ x = (self.token_embedding(x) + pos_embed).to(xa.dtype)
+
++ cache_dyn_list = []
++ cache_sta_list = []
+ for block in self.blocks:
+- x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
++ x, new_cache_dyn, new_cache_sta = block.prefill_forward(x, xa, self.mask)
++ cache_dyn_list.append(new_cache_dyn)
++ cache_sta_list.append(new_cache_sta)
++
++ cache_dyn = torch.stack(cache_dyn_list)
++ cache_sta = torch.stack(cache_sta_list)
+
+ x = self.ln(x)
+ logits = (
+ x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
+ ).float()
+
+- return logits
++ return logits, cache_dyn, cache_sta
++
++ def decode(self, x: Tensor, xa: Tensor, pos_embed: Tensor, cache_dyn: Tensor, cache_sta: Tensor):
++ x = (self.token_embedding(x) + pos_embed).to(xa.dtype)
++
++ cache_dyn_list = []
++ for idx, block in enumerate(self.blocks):
++ x, new_cache_dyn = block.decode_forward(x, xa, self.mask, cache_dyn[idx], cache_sta[idx])
++ cache_dyn_list.append(new_cache_dyn)
++
++ new_cache_dyn = torch.stack(cache_dyn_list)
++
++ x = self.ln(x)
++ logits = (
++ x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
++ ).float()
++
++ return logits, new_cache_dyn
++
++ def forward(
++ self,
++ x: Tensor,
++ xa: Tensor,
++ pos_embed: Tensor = None,
++ cache_dyn: Tensor = None,
++ cache_sta: Tensor = None,
++ ):
++ if cache_dyn is None:
++ logits, cache_dyn, cache_sta = self.prefill(x, xa, pos_embed)
++ else:
++ logits, cache_dyn = self.decode(x, xa, pos_embed, cache_dyn, cache_sta)
++ return logits, cache_dyn, cache_sta
+
+
+ class Whisper(nn.Module):
+@@ -257,7 +374,8 @@ class Whisper(nn.Module):
+ return self.encoder(mel)
+
+ def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
+- return self.decoder(tokens, audio_features)
++ pos_embed = self.decoder.positional_embedding[:tokens.shape[-1]]
++ return self.decoder(tokens, audio_features, pos_embed)
+
+ def forward(
+ self, mel: torch.Tensor, tokens: torch.Tensor