diff --git a/MindIE/MindIE-Torch/built-in/audio/Whisper/README.md b/MindIE/MindIE-Torch/built-in/audio/Whisper/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..8c8d30a3a40812a8e1333b1e4fcfd66803e4caa1
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/audio/Whisper/README.md
@@ -0,0 +1,168 @@
+# Whisper Inference Guide
+
+
+- [Overview](#ZH-CN_TOPIC_0000001172161501)
+
+- [Inference Environment](#ZH-CN_TOPIC_0000001126281702)
+
+- [Quick Start](#ZH-CN_TOPIC_0000001126281700)
+
+- [Inference Performance and Accuracy](#ZH-CN_TOPIC_0000001172201573)
+
+
+# Overview
+
+([From the open-source repository](https://github.com/openai/whisper)) Whisper is a general-purpose speech recognition model. It is trained on a large, diverse audio dataset and is a multitask model that can perform multilingual speech recognition, speech translation, and language identification.
+
+
+# Inference Environment \[All Versions\]
+
+- The model requires the following firmware, drivers, and software.
+
+  **Table 1** Version compatibility
+  | Component | Version |
+  |---------| ------- |
+  | Firmware & drivers | 24.1.rc1 |
+  | CANN | 8.0.rc1 |
+  | Python | 3.10.13 |
+  | PyTorch | 2.1.0 |
+  | MindIE | 1.0.RC2.B071 |
+
+# Quick Start
+
+1. Download the source code
+   ```
+   git clone https://github.com/openai/whisper.git
+   cd whisper
+   git reset --hard ba3f3cd54b0e5b8ce1ab3de13e32122d0d5f98ab
+   ```
+2. Export the models
+   ```
+   patch -p1 < ../trace_model.patch
+   pip3 install .
+   cd ..
+   wget https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav
+   mkdir /tmp/models
+   whisper zh.wav --model tiny
+   ```
+   These steps require `ffmpeg`; on Ubuntu it can be installed with `apt-get install ffmpeg`. When they finish, six files are generated under `/tmp/models`: `encoder.ts/onnx`, `decoder_prefill.ts/onnx`, and `decoder_decode.ts/onnx`.
+
+   To change the model path, edit `whisper/decoding.py` and `whisper/model.py` manually after applying the patch; the model loading paths used in the inference step below must be changed accordingly.
+
+   *If the download fails with an `[SSL: CERTIFICATE_VERIFY_FAILED]` error, add the following lines at the top of the Python file reported in the error and run again:*
+   ```python
+   import ssl
+   ssl._create_default_https_context = ssl._create_unverified_context
+   ```
+
+3. Compile the models
+   ```
+   python3 compile.py
+   ```
+   After compilation, four files are generated under `/tmp/models`: `encoder_compiled.ts`, `language_detection_compiled.ts`, `decoder_prefill_compiled.ts`, and `decoder_decode_compiled.ts`.
+
+   Parameters:
+   - --model_path: path of the exported TorchScript models; the compiled models are saved to the same path. Default: `/tmp/models`.
+   - --beam_size: beam search width, default 5. It must match the value used at inference time; if it was specified when exporting the models, use the same value here.
+   - --nblocks: number of decoder blocks, depending on the model size: tiny 4, base 6, small 12, medium 24, large 32.
+   - --hidden: hidden feature dimension, depending on the model size: tiny 384, base 512, small 768, medium 1024, large 1280.
+   - --soc_version: chip type; currently only verified on Ascend310P3.
+
+4. Run inference
+   ```
+   cd whisper
+   git reset --hard ba3f3cd54b0e5b8ce1ab3de13e32122d0d5f98ab
+   patch -p1 < ../mindietorch_infer.patch
+   pip3 install .
+   cd ..
+   whisper zh.wav --model tiny
+   ```
+   When inference finishes, the following is printed on the command line:
+   ```
+   [00:00.000 --> 00:04.480] 我認為跑步最重要的就是給我帶來了身體健康
+   ```
+   For simplified Chinese output, use:
+   ```
+   whisper zh.wav --model tiny --initial_prompt "简体翻译:"
+   ```
+
+   Note: the default device ID is 0 and the default model path is `/tmp/models`. To change them, edit `whisper/decoding.py` and `whisper/model.py` manually after applying the patch, for example by globally replacing `npu:0`, `/tmp/models`, and `mindietorch.set_device(0)`.
+
+
+# Inference Performance and Accuracy
+
+1. Accuracy validation
+   ```
+   python3 precision_test.py
+   ```
+
+   Parameters:
+   - --sim_threshold: cosine similarity threshold, default 0.99.
+   - --ntokens: number of input tokens in the prefill stage and of cached tokens in the decode stage, default 100.
+
+   Expected output:
+   ```
+   === Compare the outputs of ONNX and AIE ===
+   Start comparing encoder...
+   Number of outputs to compare: 1
+   Number of outputs with cosine similarity > 0.99: 1
+   Number of outputs to compare: 3
+   Number of outputs with cosine similarity > 0.99: 3
+   Number of outputs to compare: 3
+   Number of outputs with cosine similarity > 0.99: 3
+   ```
+
+2. Performance validation
+
+   a) AIE model performance test
+   ```
+   python perf_test_aie.py
+   ```
+
+   Expected output:
+   ```
+   Encoder latency: 7.75 ms
+   Encoder throughput: 128.97 fps
+   Decoder prefill latency: 10.14 ms
+   Decoder prefill throughput: 98.63 fps
+   Decoder decode latency: 2.92 ms
+   Decoder decode throughput: 342.55 fps
+   ```
+
+   b) ONNX model performance test
+   (Optional) To use a GPU, make sure CUDA and the GPU build of PyTorch are installed, then install onnxruntime-gpu as follows:
+   ```shell
+   pip uninstall onnxruntime
+   pip install onnxruntime-gpu
+   ```
+   Verify that onnxruntime-gpu was installed successfully:
+   ```python
+   import onnxruntime
+   print(onnxruntime.get_device())  # prints GPU if the installation succeeded
+   ```
+   Run the performance test:
+   ```
+   python perf_test_onnx.py --use_gpu
+   ```
+
+   Parameters:
+   - --use_gpu: enable GPU inference; without this option the CPU is used.
+
+   Expected output:
+   ```
+   Encoder latency: 59.49 ms
+   Encoder throughput: 16.81 fps
+   Decoder prefill latency: 141.14 ms
+   Decoder prefill throughput: 7.09 fps
+   Decoder decode latency: 36.05 ms
+   Decoder decode throughput: 27.74 fps
+   ```
+
+
+   | Model | PT plugin - 310P (latency / throughput) | T4 (latency / throughput) | A10 (latency / throughput) |
+   |---------|--------------------------------|---------------------|--------------------|
+   | encoder | 7.75 ms / 128.97 fps | 9.31 ms / 107.47 fps | 4.21 ms / 237.50 fps |
+   | prefill | 10.14 ms / 98.63 fps | 72.08 ms / 13.87 fps | 45.15 ms / 22.15 fps |
+   | decode | 2.92 ms / 342.55 fps | 10.46 ms / 95.62 fps | 4.91 ms / 203.61 fps |
+
+   Note: in an actual inference run the encoder and the prefill model are each called once, while the decode model is called many times (the figures above assume a cached token length of 100). The end-to-end Whisper pipeline also includes post-processing, cache rearrangement, and other steps, so the figures above are for reference only.
\ No newline at end of file
diff --git a/MindIE/MindIE-Torch/built-in/audio/Whisper/compile.py b/MindIE/MindIE-Torch/built-in/audio/Whisper/compile.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9545ced634277d6ed4a3cc97cc6507524fe0284
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/audio/Whisper/compile.py
@@ -0,0 +1,133 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
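+
+# Usage sketch (the values mirror the parameter table in the README's "Compile
+# the models" step; the non-default combination below is illustrative only):
+#   python3 compile.py                                 # tiny model, defaults: --nblocks 4 --hidden 384
+#   python3 compile.py --nblocks 6 --hidden 512        # base model
+#   python3 compile.py --model_path /tmp/models --beam_size 5 --soc_version Ascend310P3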
+ +import argparse + +import torch +import mindietorch + +_N_MEL = 80 +_FRAMES = 3000 +_HALF_FRAMES = 1500 +_MAX_TOKEN = 224 +_KV_NUM = 2 + +def parse_args(): + parser = argparse.ArgumentParser(description="mindietorch model compilation") + parser.add_argument("--model_path", default="/tmp/models") + parser.add_argument("--beam_size", type=int, default=5) + parser.add_argument("--nblocks", type=int, default=4) + parser.add_argument("--hidden", type=int, default=384) + parser.add_argument("--soc_version", default="Ascend310P3") + args = parser.parse_args() + return args + +def compile_and_save(ts_model, input_info, soc_version, save_path): + ts_model.eval() + mindie_model = mindietorch.compile( + ts_model, + inputs=input_info, + precision_policy=mindietorch._enums.PrecisionPolicy.FP16, + truncate_long_and_double=True, + allow_tensor_replace_int=True, + soc_version=soc_version, + optimization_level=0 + ) + mindie_model.save(save_path) + +def encoder(args): + ts_model = torch.jit.load(f"{args.model_path}/encoder.ts") + input_mel_info = mindietorch.Input([1, _N_MEL, _FRAMES]) + input_info = [input_mel_info] + save_path = f"{args.model_path}/encoder_compiled.ts" + compile_and_save(ts_model, input_info, args.soc_version, save_path) + +def language(args): + ts_model = torch.jit.load(f"{args.model_path}/decoder_prefill.ts") + input_tokens_info = mindietorch.Input([1, 1]) + input_audio_features_info = mindietorch.Input([1, _HALF_FRAMES, args.hidden]) + input_pos_embed_info = mindietorch.Input([1, args.hidden]) + input_info = [ + input_tokens_info, + input_audio_features_info, + input_pos_embed_info, + ] + save_path = f"{args.model_path}/language_detection_compiled.ts" + compile_and_save(ts_model, input_info, args.soc_version, save_path) + +def prefill(args): + ts_model = torch.jit.load(f"{args.model_path}/decoder_prefill.ts") + + input_tokens_info = mindietorch.Input( + min_shape=[args.beam_size, 1], + max_shape=[args.beam_size, _MAX_TOKEN] + ) + input_audio_features_info = mindietorch.Input( + min_shape=[1, _HALF_FRAMES, args.hidden], + max_shape=[1, _HALF_FRAMES, args.hidden] + ) + input_pos_embed_info = mindietorch.Input( + min_shape=[1, args.hidden], + max_shape=[_MAX_TOKEN, args.hidden] + ) + input_info = [ + input_tokens_info, + input_audio_features_info, + input_pos_embed_info, + ] + save_path = f"{args.model_path}/decoder_prefill_compiled.ts" + compile_and_save(ts_model, input_info, args.soc_version, save_path) + +def decode(args): + ts_model = torch.jit.load(f"{args.model_path}/decoder_decode.ts") + + input_tokens_info = mindietorch.Input( + min_shape=[args.beam_size, 1], + max_shape=[args.beam_size, 1] + ) + input_audio_features_info = mindietorch.Input( + min_shape=[1, _HALF_FRAMES, args.hidden], + max_shape=[1, _HALF_FRAMES, args.hidden] + ) + input_pos_embed_info = mindietorch.Input( + min_shape=[args.hidden], + max_shape=[args.hidden] + ) + input_cache_dyn_info = mindietorch.Input( + min_shape=(args.nblocks, _KV_NUM, args.beam_size, 1, args.hidden), + max_shape=(args.nblocks, _KV_NUM, args.beam_size, _MAX_TOKEN, args.hidden) + ) + input_cache_sta_info = mindietorch.Input( + min_shape=[args.nblocks, _KV_NUM, 1, _HALF_FRAMES, args.hidden], + max_shape=[args.nblocks, _KV_NUM, 1, _HALF_FRAMES, args.hidden] + ) + + input_info = [ + input_tokens_info, + input_audio_features_info, + input_pos_embed_info, + input_cache_dyn_info, + input_cache_sta_info + ] + + save_path = f"{args.model_path}/decoder_decode_compiled.ts" + compile_and_save(ts_model, input_info, args.soc_version, save_path) + 
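+# The four compile targets above correspond to the stages wired up by
+# mindietorch_infer.patch (summary based on the input shapes declared here):
+#   encoder  -> encoder_compiled.ts            static input [1, _N_MEL, _FRAMES]
+#   language -> language_detection_compiled.ts single-token, batch-1 decoder pass for language detection
+#   prefill  -> decoder_prefill_compiled.ts    token length dynamic up to _MAX_TOKEN
+#   decode   -> decoder_decode_compiled.ts     KV-cache length dynamic up to _MAX_TOKEN
+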
+def main(): + args = parse_args() + for func in encoder, language, prefill, decode: + func(args) + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/audio/Whisper/mindietorch_infer.patch b/MindIE/MindIE-Torch/built-in/audio/Whisper/mindietorch_infer.patch new file mode 100644 index 0000000000000000000000000000000000000000..fc7f771847486b8542d41a2a54876304c481399e --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/audio/Whisper/mindietorch_infer.patch @@ -0,0 +1,226 @@ +diff --git a/whisper/decoding.py b/whisper/decoding.py +index 49485d0..4dccc86 100644 +--- a/whisper/decoding.py ++++ b/whisper/decoding.py +@@ -6,6 +6,7 @@ import torch + import torch.nn.functional as F + from torch import Tensor + from torch.distributions import Categorical ++import mindietorch + + from .audio import CHUNK_LENGTH + from .tokenizer import Tokenizer, get_tokenizer +@@ -14,6 +15,7 @@ from .utils import compression_ratio + if TYPE_CHECKING: + from .model import Whisper + ++mindietorch.set_device(0) + + @torch.no_grad() + def detect_language( +@@ -54,7 +56,7 @@ def detect_language( + # forward pass using a single token, startoftranscript + n_audio = mel.shape[0] + x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device) # [n_audio, 1] +- logits = model.logits(x, mel)[:, 0] ++ logits = model.logits(x, mel)[0][:, 0] + + # collect detected languages; suppress all non-language tokens + mask = torch.ones(logits.shape[-1], dtype=torch.bool) +@@ -145,36 +147,35 @@ class PyTorchInference(Inference): + def __init__(self, model: "Whisper", initial_token_length: int): + self.model: "Whisper" = model + self.initial_token_length = initial_token_length +- self.kv_cache = {} +- self.hooks = [] +- +- key_modules = [block.attn.key for block in self.model.decoder.blocks] +- value_modules = [block.attn.value for block in self.model.decoder.blocks] +- self.kv_modules = key_modules + value_modules ++ self.cache_dyn = None ++ self.cache_sta = None + + def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor: +- if not self.kv_cache: +- self.kv_cache, self.hooks = self.model.install_kv_cache_hooks() +- + if tokens.shape[-1] > self.initial_token_length: + # only need to use the last token except in the first forward pass + tokens = tokens[:, -1:] ++ pos_embed = self.model.decoder.positional_embedding[self.cache_dyn.shape[3]] ++ logits, cache_dyn, _ = self.model.decoder( ++ tokens, audio_features, pos_embed, self.cache_dyn, self.cache_sta) ++ self.cache_dyn = cache_dyn ++ else: ++ pos_embed = self.model.decoder.positional_embedding[:tokens.shape[-1]] ++ logits, cache_dyn, cache_sta = self.model.decoder(tokens, audio_features, pos_embed) ++ self.cache_dyn = cache_dyn ++ self.cache_sta = cache_sta + +- return self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache) ++ return logits + + def cleanup_caching(self): +- for hook in self.hooks: +- hook.remove() +- +- self.kv_cache = {} +- self.hooks = [] ++ self.cache_dyn = None ++ self.cache_sta = None + + def rearrange_kv_cache(self, source_indices): + if source_indices != list(range(len(source_indices))): +- for module in self.kv_modules: +- # update the key/value cache to contain the selected sequences +- self.kv_cache[module] = self.kv_cache[module][source_indices].detach() +- ++ blocks = self.cache_dyn.shape[0] ++ for i in range(blocks): ++ for j in range(2): # k and v 2 items ++ self.cache_dyn[i][j] = self.cache_dyn[i][j][source_indices] + + class SequenceRanker: + def rank( +diff --git a/whisper/model.py 
b/whisper/model.py +index a678283..c94a024 100644 +--- a/whisper/model.py ++++ b/whisper/model.py +@@ -1,12 +1,14 @@ + import base64 + import gzip + from dataclasses import dataclass ++import os + from typing import Dict, Iterable, Optional + + import numpy as np + import torch + import torch.nn.functional as F + from torch import Tensor, nn ++import mindietorch + + from .decoding import decode as decode_function + from .decoding import detect_language as detect_language_function +@@ -153,24 +155,19 @@ class AudioEncoder(nn.Module): + [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)] + ) + self.ln_post = LayerNorm(n_state) ++ self.device = "npu:0" ++ self.mindietorch_encoder_model = torch.jit.load( ++ "/tmp/models/encoder_compiled.ts" ++ ).eval().to(self.device) + + def forward(self, x: Tensor): + """ + x : torch.Tensor, shape = (batch_size, n_mels, n_ctx) + the mel spectrogram of the audio + """ +- x = F.gelu(self.conv1(x)) +- x = F.gelu(self.conv2(x)) +- x = x.permute(0, 2, 1) +- +- assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape" +- x = (x + self.positional_embedding).to(x.dtype) +- +- for block in self.blocks: +- x = block(x) +- +- x = self.ln_post(x) +- return x ++ x = x.to(self.device) ++ x = self.mindietorch_encoder_model(x) ++ return x.cpu() + + + class TextDecoder(nn.Module): +@@ -193,29 +190,58 @@ class TextDecoder(nn.Module): + mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1) + self.register_buffer("mask", mask, persistent=False) + +- def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None): +- """ +- x : torch.LongTensor, shape = (batch_size, <= n_ctx) +- the text tokens +- xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state) +- the encoded audio features to be attended on +- """ +- offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0 +- x = ( +- self.token_embedding(x) +- + self.positional_embedding[offset : offset + x.shape[-1]] +- ) +- x = x.to(xa.dtype) +- +- for block in self.blocks: +- x = block(x, xa, mask=self.mask, kv_cache=kv_cache) ++ self.device = "npu:0" ++ self.mindietorch_language_detection_model = torch.jit.load( ++ "/tmp/models/language_detection_compiled.ts" ++ ).eval().to(self.device) ++ self.mindietorch_prefill_model = torch.jit.load( ++ "/tmp/models/decoder_prefill_compiled.ts" ++ ).eval().to(self.device) ++ self.mindietorch_decode_model = torch.jit.load( ++ "/tmp/models/decoder_decode_compiled.ts" ++ ).eval().to(self.device) + +- x = self.ln(x) +- logits = ( +- x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1) +- ).float() +- +- return logits ++ def forward( ++ self, ++ x: Tensor, ++ xa: Tensor, ++ pos_embed: Tensor = None, ++ cache_dyn: Tensor = None, ++ cache_sta: Tensor = None, ++ ): ++ if cache_dyn is None: ++ tokens_npu = x.float().to(self.device) ++ audio_features_npu = xa.to(self.device) ++ pos_embed_npu = pos_embed.to(self.device) ++ if x.shape[0] != 1: ++ logits, cache_dyn, cache_sta = self.mindietorch_prefill_model( ++ tokens_npu, ++ audio_features_npu, ++ pos_embed_npu ++ ) ++ else: ++ logits, cache_dyn, cache_sta = self.mindietorch_language_detection_model( ++ tokens_npu, ++ audio_features_npu, ++ pos_embed_npu ++ ) ++ logits = logits.cpu() ++ cache_dyn = cache_dyn.cpu() ++ else: ++ tokens_npu = x.float().to(self.device) ++ audio_features_npu = xa.to(self.device) ++ pos_embed_npu = pos_embed.to(self.device) ++ cache_dyn_npu = cache_dyn.to(self.device) ++ logits, cache_dyn, _ = self.mindietorch_decode_model( ++ tokens_npu, ++ 
audio_features_npu, ++ pos_embed_npu, ++ cache_dyn_npu, ++ cache_sta ++ ) ++ logits = logits.cpu() ++ cache_dyn = cache_dyn.cpu() ++ return logits, cache_dyn, cache_sta + + + class Whisper(nn.Module): +@@ -257,7 +283,8 @@ class Whisper(nn.Module): + return self.encoder(mel) + + def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor): +- return self.decoder(tokens, audio_features) ++ pos_embed = self.decoder.positional_embedding[:tokens.shape[-1]] ++ return self.decoder(tokens, audio_features, pos_embed) + + def forward( + self, mel: torch.Tensor, tokens: torch.Tensor diff --git a/MindIE/MindIE-Torch/built-in/audio/Whisper/perf_test_aie.py b/MindIE/MindIE-Torch/built-in/audio/Whisper/perf_test_aie.py new file mode 100644 index 0000000000000000000000000000000000000000..6f7cfdc2192a7f41c015e9259469e0101c726a5b --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/audio/Whisper/perf_test_aie.py @@ -0,0 +1,128 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import time +import torch +import mindietorch + +_N_MEL = 80 +_FRAMES = 3000 +_HALF_FRAMES = 1500 +_MAX_TOKEN = 224 +_KV_NUM = 2 + + +def test(inputs, model, stream, meta=""): + # warmup + for _ in range(10): + with mindietorch.npu.stream(stream): + model(*inputs) + stream.synchronize() + + # performance test + num_infer = 100 + start = time.time() + for _ in range(num_infer): + with mindietorch.npu.stream(stream): + model(*inputs) + stream.synchronize() + end = time.time() + + print(f"{meta} latency: {(end - start) / num_infer * 1000:.2f} ms") + print(f"{meta} throughput: {num_infer / (end - start):.2f} fps") + + +def test_encoder(args): + device = f'npu:{args.device_id}' + stream = mindietorch.npu.Stream(device) + model = torch.jit.load(args.encoder_aie_path) + model.eval() + + inputs = [ + torch.ones((1, _N_MEL, _FRAMES), dtype=torch.float32).to(device) + ] + + test(inputs, model, stream, "Encoder") + + +def test_decoder_prefill(args): + device = f'npu:{args.device_id}' + stream = mindietorch.npu.Stream(device) + model = torch.jit.load(args.decoder_prefill_aie_path) + model.eval() + + assert args.ntokens <= _MAX_TOKEN, f'ntokens can not exceed {_MAX_TOKEN}' + + inputs = [ + torch.ones((args.beam_size, args.ntokens), dtype=torch.float32).to(device), + torch.ones((1, _HALF_FRAMES, args.hidden), dtype=torch.float32).to(device), + torch.ones((args.ntokens, args.hidden), dtype=torch.float32).to(device) + ] + + test(inputs, model, stream, "Decoder prefill") + + +def test_decoder_decode(args): + device = f'npu:{args.device_id}' + stream = mindietorch.npu.Stream(device) + model = torch.jit.load(args.decoder_decode_aie_path) + model.eval() + + inputs = [ + torch.ones((args.beam_size, 1), dtype=torch.float32).to(device), + torch.ones((1, _HALF_FRAMES, args.hidden), dtype=torch.float32).to(device), + torch.ones((args.hidden), dtype=torch.float32).to(device), + torch.ones((args.nblocks, _KV_NUM, args.beam_size, args.ntokens, args.hidden), dtype=torch.float32).to(device), + 
torch.ones((args.nblocks, _KV_NUM, 1, _HALF_FRAMES, args.hidden), dtype=torch.float32).to(device), + ] + + test(inputs, model, stream, "Decoder decode") + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--encoder_aie_path", + type=str, default="/tmp/models/encoder_compiled.ts" + ) + parser.add_argument( + "--decoder_prefill_aie_path", + type=str, default="/tmp/models/decoder_prefill_compiled.ts" + ) + parser.add_argument( + "--decoder_decode_aie_path", + type=str, default="/tmp/models/decoder_decode_compiled.ts" + ) + parser.add_argument("--beam_size", type=int, default=5) + parser.add_argument("--ntokens", type=int, default=100) + parser.add_argument("--nblocks", type=int, default=4) + parser.add_argument("--hidden", type=int, default=384) + parser.add_argument("--device_id", type=int, help="NPU device id", default=0) + + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + + mindietorch.set_device(args.device_id) + + for func in test_encoder, test_decoder_prefill, test_decoder_decode: + func(args) + + +if __name__ == "__main__": + main() diff --git a/MindIE/MindIE-Torch/built-in/audio/Whisper/perf_test_onnx.py b/MindIE/MindIE-Torch/built-in/audio/Whisper/perf_test_onnx.py new file mode 100644 index 0000000000000000000000000000000000000000..bd893efd82fc5ac8987c7c790d37ea386cbf4845 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/audio/Whisper/perf_test_onnx.py @@ -0,0 +1,120 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
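+
+# Usage (see the README's "ONNX model performance test" step):
+#   python perf_test_onnx.py              # CPU, uses CPUExecutionProvider
+#   python perf_test_onnx.py --use_gpu    # requires onnxruntime-gpu, uses CUDAExecutionProvider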
+ +import argparse +import time +import onnxruntime as ort +import numpy as np + +_N_MEL = 80 +_FRAMES = 3000 +_MAX_TOKEN = 224 +_HALF_FRAMES = 1500 +_KV_NUM = 2 + + +def test(encoder_path, provider, output_names, onnx_inputs, meta=""): + onnx_model = ort.InferenceSession( + encoder_path, + providers=[provider] + ) + + # warmup + for _ in range(10): + onnx_model.run(output_names, onnx_inputs) + # performance test + num_infer = 100 + start = time.time() + for _ in range(num_infer): + onnx_model.run(output_names, onnx_inputs) + end = time.time() + + print(f"{meta} latency: {(end - start) / num_infer * 1000:.2f} ms") + print(f"{meta} throughput: {num_infer / (end - start):.2f} fps") + + +def test_encoder(args, provider): + x = np.ones((1, _N_MEL, _FRAMES), dtype=np.float16 if args.use_gpu else np.float32) + onnx_inputs = {'mel': ort.OrtValue.ortvalue_from_numpy(x)} + output_names = ['ret'] + + test(args.encoder_onnx_path, provider, output_names, onnx_inputs, "Encoder") + + +def test_decoder_prefill(args, provider): + assert args.ntokens <= _MAX_TOKEN, f'ntokens can not exceed {_MAX_TOKEN}' + tokens = np.ones((args.beam_size, args.ntokens), dtype=np.int64) + audio_features = np.ones((1, _HALF_FRAMES, args.hidden), dtype=np.float16 if args.use_gpu else np.float32) + pos_embed = np.ones((args.ntokens, args.hidden), dtype=np.float32) + onnx_inputs = { + 'tokens': ort.OrtValue.ortvalue_from_numpy(tokens), + 'audio_features': ort.OrtValue.ortvalue_from_numpy(audio_features), + 'pos_embed': ort.OrtValue.ortvalue_from_numpy(pos_embed) + } + output_names = ["logits", "cache_dyn", "cache_sta"] + + test(args.decoder_prefill_onnx_path, provider, output_names, onnx_inputs, "Decoder prefill") + + +def test_decoder_decode(args, provider): + assert args.ntokens <= _MAX_TOKEN, f'ntokens can not exceed {_MAX_TOKEN}' + tokens = np.ones((args.beam_size, 1), dtype=np.int64) + pos_embed = np.ones((args.hidden), dtype=np.float32) + cache_dyn = np.ones( + (args.nblocks, _KV_NUM, args.beam_size, args.ntokens, args.hidden), + dtype=np.float16 if args.use_gpu else np.float32 + ) + cache_sta = np.ones( + (args.nblocks, _KV_NUM, 1, _HALF_FRAMES, args.hidden), + dtype=np.float16 if args.use_gpu else np.float32 + ) + onnx_inputs = { + 'tokens': ort.OrtValue.ortvalue_from_numpy(tokens), # audio_features onnx导出被折叠 + 'pos_embed': ort.OrtValue.ortvalue_from_numpy(pos_embed), + 'cache_dyn': ort.OrtValue.ortvalue_from_numpy(cache_dyn), + 'cache_sta': ort.OrtValue.ortvalue_from_numpy(cache_sta) + } + output_names = ["logits", "new_cache_dyn", "new_cache_sta"] + + test(args.decoder_decode_onnx_path, provider, output_names, onnx_inputs, "Decoder decode") + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--encoder_onnx_path',type=str, default='/tmp/models/encoder.onnx') + parser.add_argument('--decoder_prefill_onnx_path',type=str, default='/tmp/models/decoder_prefill.onnx') + parser.add_argument('--decoder_decode_onnx_path',type=str, default='/tmp/models/decoder_decode.onnx') + parser.add_argument("--use_gpu", action="store_true") + parser.add_argument("--beam_size", type=int, default=5) + parser.add_argument("--ntokens", type=int, default=100) + parser.add_argument("--nblocks", type=int, default=4) + parser.add_argument("--hidden", type=int, default=384) + + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + if args.use_gpu: + provider = "CUDAExecutionProvider" + else: + provider = "CPUExecutionProvider" + + for func in test_encoder, test_decoder_prefill, 
test_decoder_decode: + func(args, provider) + + +if __name__ == "__main__": + main() diff --git a/MindIE/MindIE-Torch/built-in/audio/Whisper/precision_test.py b/MindIE/MindIE-Torch/built-in/audio/Whisper/precision_test.py new file mode 100644 index 0000000000000000000000000000000000000000..de9f03ae899e21317e8e05d310bfc071dd23a1bc --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/audio/Whisper/precision_test.py @@ -0,0 +1,191 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import torch +import torch.nn.functional as F +import onnxruntime as ort +import numpy as np +import mindietorch + +_N_MEL = 80 +_FRAMES = 3000 +_MAX_TOKEN = 224 +_HALF_FRAMES = 1500 +_KV_NUM = 2 + +def compare_onnx_aie_output(onnx_out, aie_out, sim_threshold=0.99): + num_sim = 0 + for i, (a, b) in enumerate(zip(onnx_out, aie_out)): + a = a.reshape(1, -1).astype(np.float32) + b = b.reshape(1, -1) + sim = F.cosine_similarity(torch.from_numpy(a), b, dim=1) + if sim > sim_threshold: + num_sim += 1 + else: + print(f'Output {i} similarity: {sim}') + + print(f'Number of outputs to compare: {len(onnx_out)}') + print(f'Number of outputs with cosine similarity > {sim_threshold}: {num_sim}') + + +def compare_encoder(args): + device = f'npu:{args.device_id}' + + onnx_model = ort.InferenceSession( + args.encoder_onnx_path, + providers=["CPUExecutionProvider"] + ) + + x = np.ones((1, _N_MEL, _FRAMES), dtype=np.float32) + onnx_inputs = {'mel': ort.OrtValue.ortvalue_from_numpy(x)} + output_names = ['ret'] + onnx_out = onnx_model.run(output_names, onnx_inputs) + + aie_inputs = [x] + for i in range(len(aie_inputs)): + aie_inputs[i] = torch.from_numpy(aie_inputs[i]).to(device) + + mindietorch.set_device(args.device_id) + stream = mindietorch.npu.Stream(device) + model = torch.jit.load(args.encoder_aie_path) + model.eval().to(device) + + with mindietorch.npu.stream(stream): + aie_out = model(*aie_inputs) + stream.synchronize() + + if isinstance(aie_out, tuple): + aie_out = (x.cpu() for x in aie_out) + else: + aie_out = aie_out.cpu() + compare_onnx_aie_output(onnx_out, aie_out, args.sim_threshold) + + +def compare_decoder_prefill(args): + device = f'npu:{args.device_id}' + + onnx_model = ort.InferenceSession( + args.decoder_prefill_onnx_path, + providers=["CPUExecutionProvider"] + ) + + assert args.ntokens <= _MAX_TOKEN, f'ntokens can not exceed {_MAX_TOKEN}' + tokens = np.ones((args.beam_size, args.ntokens), dtype=np.int64) + audio_features = np.ones((1, _HALF_FRAMES, args.hidden), dtype=np.float32) + pos_embed = np.ones((args.ntokens, args.hidden), dtype=np.float32) + onnx_inputs = { + 'tokens': ort.OrtValue.ortvalue_from_numpy(tokens), + 'audio_features': ort.OrtValue.ortvalue_from_numpy(audio_features), + 'pos_embed': ort.OrtValue.ortvalue_from_numpy(pos_embed) + } + output_names = ["logits", "cache_dyn", "cache_sta"] + onnx_out = onnx_model.run(output_names, onnx_inputs) + + aie_inputs = [tokens.astype(np.float32), audio_features, pos_embed] + for i in 
range(len(aie_inputs)): + aie_inputs[i] = torch.from_numpy(aie_inputs[i]).to(device) + + mindietorch.set_device(args.device_id) + stream = mindietorch.npu.Stream(device) + model = torch.jit.load(args.decoder_prefill_aie_path) + model.eval().to(device) + + with mindietorch.npu.stream(stream): + aie_out = model(*aie_inputs) + stream.synchronize() + if isinstance(aie_out, tuple): + aie_out = (x.cpu() for x in aie_out) + else: + aie_out = aie_out.cpu() + compare_onnx_aie_output(onnx_out, aie_out, args.sim_threshold) + + +def compare_decoder_decode(args): + device = f'npu:{args.device_id}' + + onnx_model = ort.InferenceSession( + args.decoder_decode_onnx_path, + providers=["CPUExecutionProvider"] + ) + + assert args.ntokens <= _MAX_TOKEN, f'ntokens can not exceed {_MAX_TOKEN}' + tokens = np.ones((args.beam_size, 1), dtype=np.int64) + audio_features = np.ones((1, _HALF_FRAMES, args.hidden), dtype=np.float32) + pos_embed = np.ones((args.hidden), dtype=np.float32) + cache_dyn = np.ones((args.nblocks, _KV_NUM, args.beam_size, args.ntokens, args.hidden), dtype=np.float32) + cache_sta = np.ones((args.nblocks, _KV_NUM, 1, _HALF_FRAMES, args.hidden), dtype=np.float32) + onnx_inputs = { + 'tokens': ort.OrtValue.ortvalue_from_numpy(tokens), # audio_features onnx导出被折叠 + 'pos_embed': ort.OrtValue.ortvalue_from_numpy(pos_embed), + 'cache_dyn': ort.OrtValue.ortvalue_from_numpy(cache_dyn), + 'cache_sta': ort.OrtValue.ortvalue_from_numpy(cache_sta) + } + + output_names = ["logits", "new_cache_dyn", "new_cache_sta"] + onnx_out = onnx_model.run(output_names, onnx_inputs) + + aie_inputs = [tokens.astype(np.float32), audio_features, pos_embed, cache_dyn, cache_sta] + for i in range(len(aie_inputs)): + aie_inputs[i] = torch.from_numpy(aie_inputs[i]).to(device) + + mindietorch.set_device(args.device_id) + stream = mindietorch.npu.Stream(device) + model = torch.jit.load(args.decoder_decode_aie_path) + model.eval().to(device) + + with mindietorch.npu.stream(stream): + aie_out = model(*aie_inputs) + stream.synchronize() + if isinstance(aie_out, tuple): + aie_out = (x.cpu() for x in aie_out) + else: + aie_out = aie_out.cpu() + compare_onnx_aie_output(onnx_out, aie_out, args.sim_threshold) + + +def parse_args(): + parser = argparse.ArgumentParser() + # encoder + parser.add_argument('--encoder_onnx_path',type=str, default='/tmp/models/encoder.onnx') + parser.add_argument('--encoder_aie_path', type=str, default='/tmp/models/encoder_compiled.ts') + # decoder_prefill + parser.add_argument('--decoder_prefill_onnx_path',type=str, default='/tmp/models/decoder_prefill.onnx') + parser.add_argument('--decoder_prefill_aie_path', type=str, default='/tmp/models/decoder_prefill_compiled.ts') + # decoder_decode + parser.add_argument('--decoder_decode_onnx_path',type=str, default='/tmp/models/decoder_decode.onnx') + parser.add_argument('--decoder_decode_aie_path', type=str, default='/tmp/models/decoder_decode_compiled.ts') + parser.add_argument('--sim_threshold', type=float, default=0.99) + parser.add_argument('--device_id', type=int, default=0) + parser.add_argument("--beam_size", type=int, default=5) + parser.add_argument("--ntokens", type=int, default=100) + parser.add_argument("--nblocks", type=int, default=4) + parser.add_argument("--hidden", type=int, default=384) + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + + print('=== Compare the outputs of ONNX and AIE ===') + + print('Start comparing encoder...') + funcs = [compare_encoder, compare_decoder_prefill, compare_decoder_decode] + for func in 
funcs: + func(args) + + +if __name__ == "__main__": + main() diff --git a/MindIE/MindIE-Torch/built-in/audio/Whisper/trace_model.patch b/MindIE/MindIE-Torch/built-in/audio/Whisper/trace_model.patch new file mode 100644 index 0000000000000000000000000000000000000000..a35756ff38f412308096baa0caf2da2880b71613 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/audio/Whisper/trace_model.patch @@ -0,0 +1,343 @@ +diff --git a/whisper/decoding.py b/whisper/decoding.py +index 49485d0..495fe45 100644 +--- a/whisper/decoding.py ++++ b/whisper/decoding.py +@@ -2,6 +2,7 @@ from dataclasses import dataclass, field, replace + from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union + + import numpy as np ++import os + import torch + import torch.nn.functional as F + from torch import Tensor +@@ -49,12 +50,24 @@ def detect_language( + + # skip encoder forward pass if already-encoded audio features were given + if mel.shape[-2:] != (model.dims.n_audio_ctx, model.dims.n_audio_state): ++ encoder_ts_model = torch.jit.trace(model.encoder, mel) ++ encoder_ts_model.save( ++ "/tmp/models/encoder.ts") ++ torch.onnx.export( ++ model.encoder, ++ (mel), ++ "/tmp/models/encoder.onnx", ++ opset_version=11, ++ input_names=["mel"], ++ output_names=["ret"] ++ ) ++ + mel = model.encoder(mel) + + # forward pass using a single token, startoftranscript + n_audio = mel.shape[0] + x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device) # [n_audio, 1] +- logits = model.logits(x, mel)[:, 0] ++ logits = model.logits(x, mel)[0][:, 0] + + # collect detected languages; suppress all non-language tokens + mask = torch.ones(logits.shape[-1], dtype=torch.bool) +@@ -145,36 +158,74 @@ class PyTorchInference(Inference): + def __init__(self, model: "Whisper", initial_token_length: int): + self.model: "Whisper" = model + self.initial_token_length = initial_token_length +- self.kv_cache = {} +- self.hooks = [] +- +- key_modules = [block.attn.key for block in self.model.decoder.blocks] +- value_modules = [block.attn.value for block in self.model.decoder.blocks] +- self.kv_modules = key_modules + value_modules ++ self.cache_dyn = None ++ self.cache_sta = None + + def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor: +- if not self.kv_cache: +- self.kv_cache, self.hooks = self.model.install_kv_cache_hooks() +- + if tokens.shape[-1] > self.initial_token_length: + # only need to use the last token except in the first forward pass + tokens = tokens[:, -1:] ++ pos_embed = self.model.decoder.positional_embedding[self.cache_dyn.shape[3]] ++ torch.onnx.export( ++ self.model.decoder, ++ (tokens, audio_features, pos_embed, self.cache_dyn, self.cache_sta), ++ "/tmp/models/decoder_decode.onnx", ++ opset_version=11, ++ input_names=["tokens", "audio_features", "pos_embed", "cache_dyn", "cache_sta"], ++ output_names=["logits", "new_cache_dyn", "new_cache_sta"], ++ dynamic_axes={ ++ "cache_dyn": {3: "ntokens"}, ++ "new_cache_dyn": {3: "ntokens"} ++ } ++ ) ++ decoder_decode_ts_model = torch.jit.trace( ++ self.model.decoder, ++ (tokens, audio_features, pos_embed, self.cache_dyn, self.cache_sta) ++ ) ++ decoder_decode_ts_model.save( ++ "/tmp/models/decoder_decode.ts") ++ logits, cache_dyn, _ = self.model.decoder( ++ tokens, audio_features, pos_embed, self.cache_dyn, self.cache_sta) ++ os.sys.exit(0) ++ self.cache_dyn = cache_dyn ++ else: ++ pos_embed = self.model.decoder.positional_embedding[:tokens.shape[-1]] ++ torch.onnx.export( ++ self.model.decoder, ++ (tokens, audio_features, pos_embed), ++ 
"/tmp/models/decoder_prefill.onnx", ++ opset_version=11, ++ input_names=["tokens", "audio_features", "pos_embed"], ++ output_names=["logits", "cache_dyn", "cache_sta"], ++ dynamic_axes={ ++ "tokens": {1: "ntokens"}, ++ "pos_embed": {0: "ntokens"}, ++ "logits": {1: "ntokens"}, ++ "cache_dyn": {3: "ntokens"} ++ } ++ ) ++ decoder_prefill_ts_model = torch.jit.trace( ++ self.model.decoder, ++ (tokens, audio_features, pos_embed) ++ ) ++ decoder_prefill_ts_model.save( ++ "/tmp/models/decoder_prefill.ts") ++ logits, cache_dyn, cache_sta = self.model.decoder(tokens, audio_features, pos_embed) ++ self.cache_dyn = cache_dyn ++ self.cache_sta = cache_sta + +- return self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache) ++ return logits + + def cleanup_caching(self): +- for hook in self.hooks: +- hook.remove() +- +- self.kv_cache = {} +- self.hooks = [] ++ self.cache_dyn = None ++ self.cache_sta = None + + def rearrange_kv_cache(self, source_indices): + if source_indices != list(range(len(source_indices))): +- for module in self.kv_modules: +- # update the key/value cache to contain the selected sequences +- self.kv_cache[module] = self.kv_cache[module][source_indices].detach() +- ++ blocks = self.cache_dyn.shape[0] ++ for i in range(blocks): ++ for j in range(2): # k and v 2 items ++ self.cache_dyn[i][j] = self.cache_dyn[i][j][source_indices] + + class SequenceRanker: + def rank( +diff --git a/whisper/model.py b/whisper/model.py +index a678283..2a95e28 100644 +--- a/whisper/model.py ++++ b/whisper/model.py +@@ -1,6 +1,7 @@ + import base64 + import gzip + from dataclasses import dataclass ++import os + from typing import Dict, Iterable, Optional + + import numpy as np +@@ -68,6 +69,63 @@ class MultiHeadAttention(nn.Module): + self.value = Linear(n_state, n_state) + self.out = Linear(n_state, n_state) + ++ def encoder_forward(self, x: Tensor): ++ q = self.query(x) ++ k = self.key(x) ++ v = self.value(x) ++ wv, qk = self.qkv_attention(q, k, v) ++ return self.out(wv) ++ ++ def prefill_self_attn_forward( ++ self, ++ x: Tensor, ++ mask: Tensor, ++ ): ++ q = self.query(x) ++ k = self.key(x) ++ v = self.value(x) ++ cache_dyn = torch.stack([k, v]) ++ wv, _ = self.qkv_attention(q, k, v, mask) ++ return self.out(wv), cache_dyn ++ ++ def prefill_cross_attn_forward( ++ self, ++ x: Tensor, ++ xa: Tensor, ++ ): ++ q = self.query(x) ++ k = self.key(xa) ++ v = self.value(xa) ++ cache_sta = torch.stack([k, v]) ++ wv, _ = self.qkv_attention(q, k, v) ++ return self.out(wv), cache_sta ++ ++ def decode_self_attn_forward( ++ self, ++ x: Tensor, ++ mask: Tensor, ++ cache_dyn: Tensor ++ ): ++ q = self.query(x) ++ token_k = self.key(x) ++ k = torch.cat([cache_dyn[0], token_k], dim=1).detach() ++ token_v = self.value(x) ++ v = torch.cat([cache_dyn[1], token_v], dim=1).detach() ++ new_cache_dyn = torch.stack([k, v]) ++ wv, _ = self.qkv_attention(q, k, v, mask) ++ return self.out(wv), new_cache_dyn ++ ++ def decode_cross_attn_forward( ++ self, ++ x: Tensor, ++ cache_sta: Tensor ++ ): ++ q = self.query(x) ++ k = cache_sta[0] ++ v = cache_sta[1] ++ wv, _ = self.qkv_attention(q, k, v) ++ return self.out(wv) ++ + def forward( + self, + x: Tensor, +@@ -126,6 +184,39 @@ class ResidualAttentionBlock(nn.Module): + ) + self.mlp_ln = LayerNorm(n_state) + ++ def encoder_forward(self, x: Tensor): ++ x = x + self.attn.encoder_forward(self.attn_ln(x)) ++ x = x + self.mlp(self.mlp_ln(x)) ++ return x ++ ++ def prefill_forward( ++ self, ++ x: Tensor, ++ xa: Tensor, ++ mask: Tensor, ++ ): ++ self_attn_out, new_cache_dyn = 
self.attn.prefill_self_attn_forward(self.attn_ln(x), mask) ++ x = x + self_attn_out ++ cross_attn_out, new_cache_sta = self.cross_attn.prefill_cross_attn_forward(self.cross_attn_ln(x), xa) ++ x = x + cross_attn_out ++ x = x + self.mlp(self.mlp_ln(x)) ++ return x, new_cache_dyn, new_cache_sta ++ ++ def decode_forward( ++ self, ++ x: Tensor, ++ xa: Tensor, ++ mask: Tensor, ++ cache_dyn: Tensor, ++ cache_sta: Tensor ++ ): ++ self_attn_out, new_cache_dyn = self.attn.decode_self_attn_forward(self.attn_ln(x), mask, cache_dyn) ++ x = x + self_attn_out ++ cross_attn_out = self.cross_attn.decode_cross_attn_forward(self.cross_attn_ln(x), cache_sta) ++ x = x + cross_attn_out ++ x = x + self.mlp(self.mlp_ln(x)) ++ return x, new_cache_dyn ++ + def forward( + self, + x: Tensor, +@@ -163,11 +254,10 @@ class AudioEncoder(nn.Module): + x = F.gelu(self.conv2(x)) + x = x.permute(0, 2, 1) + +- assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape" + x = (x + self.positional_embedding).to(x.dtype) + + for block in self.blocks: +- x = block(x) ++ x = block.encoder_forward(x) + + x = self.ln_post(x) + return x +@@ -193,29 +283,56 @@ class TextDecoder(nn.Module): + mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1) + self.register_buffer("mask", mask, persistent=False) + +- def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None): +- """ +- x : torch.LongTensor, shape = (batch_size, <= n_ctx) +- the text tokens +- xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state) +- the encoded audio features to be attended on +- """ +- offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0 +- x = ( +- self.token_embedding(x) +- + self.positional_embedding[offset : offset + x.shape[-1]] +- ) +- x = x.to(xa.dtype) ++ def prefill(self, x: Tensor, xa: Tensor, pos_embed: Tensor): ++ x = (self.token_embedding(x) + pos_embed).to(xa.dtype) + ++ cache_dyn_list = [] ++ cache_sta_list = [] + for block in self.blocks: +- x = block(x, xa, mask=self.mask, kv_cache=kv_cache) ++ x, new_cache_dyn, new_cache_sta = block.prefill_forward(x, xa, self.mask) ++ cache_dyn_list.append(new_cache_dyn) ++ cache_sta_list.append(new_cache_sta) ++ ++ cache_dyn = torch.stack(cache_dyn_list) ++ cache_sta = torch.stack(cache_sta_list) + + x = self.ln(x) + logits = ( + x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1) + ).float() + +- return logits ++ return logits, cache_dyn, cache_sta ++ ++ def decode(self, x: Tensor, xa: Tensor, pos_embed: Tensor, cache_dyn: Tensor, cache_sta: Tensor): ++ x = (self.token_embedding(x) + pos_embed).to(xa.dtype) ++ ++ cache_dyn_list = [] ++ for idx, block in enumerate(self.blocks): ++ x, new_cache_dyn = block.decode_forward(x, xa, self.mask, cache_dyn[idx], cache_sta[idx]) ++ cache_dyn_list.append(new_cache_dyn) ++ ++ new_cache_dyn = torch.stack(cache_dyn_list) ++ ++ x = self.ln(x) ++ logits = ( ++ x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1) ++ ).float() ++ ++ return logits, new_cache_dyn ++ ++ def forward( ++ self, ++ x: Tensor, ++ xa: Tensor, ++ pos_embed: Tensor = None, ++ cache_dyn: Tensor = None, ++ cache_sta: Tensor = None, ++ ): ++ if cache_dyn is None: ++ logits, cache_dyn, cache_sta = self.prefill(x, xa, pos_embed) ++ else: ++ logits, cache_dyn = self.decode(x, xa, pos_embed, cache_dyn, cache_sta) ++ return logits, cache_dyn, cache_sta + + + class Whisper(nn.Module): +@@ -257,7 +374,8 @@ class Whisper(nn.Module): + return self.encoder(mel) + + def logits(self, tokens: torch.Tensor, audio_features: 
torch.Tensor): +- return self.decoder(tokens, audio_features) ++ pos_embed = self.decoder.positional_embedding[:tokens.shape[-1]] ++ return self.decoder(tokens, audio_features, pos_embed) + + def forward( + self, mel: torch.Tensor, tokens: torch.Tensor